diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..24b0494a9f15e592e3a494a41e71a1204821d026 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,33 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+diffusers/src/diffusers/models/__pycache__/attention_processor.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffusers/src/diffusers/models/unets/__pycache__/unet_2d_blocks.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffusers/src/diffusers/pipelines/__pycache__/pipeline_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+examples/line/example0/input.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example0/reference_image_0.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example0/reference_image_1.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example0/reference_image_2.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example1/input.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example1/reference_image_0.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example2/input.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example2/reference_image_0.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example2/reference_image_1.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example2/reference_image_2.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example2/reference_image_3.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example3/input.png filter=lfs diff=lfs merge=lfs -text
+examples/line/example3/reference_image_0.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example0/input.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example0/reference_image_0.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example0/reference_image_1.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example0/reference_image_2.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example0/reference_image_3.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example1/input.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example1/reference_image_0.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example1/reference_image_1.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example1/reference_image_2.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example1/reference_image_3.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example1/reference_image_4.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example1/reference_image_5.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example2/input.png filter=lfs diff=lfs merge=lfs -text
+examples/shadow/example2/reference_image_0.png filter=lfs diff=lfs merge=lfs -text
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedd8bd95b357ea19aad88212fe327a3d85e6b09
--- /dev/null
+++ b/app.py
@@ -0,0 +1,699 @@
+import contextlib
+import gc
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import sys
+import time
+import itertools
+import copy
+from pathlib import Path
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import accelerate
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from safetensors.torch import load_model
+from peft import LoraConfig
+import gradio as gr
+import pandas as pd
+
+import transformers
+from transformers import (
+ AutoTokenizer,
+ PretrainedConfig,
+ CLIPVisionModelWithProjection,
+ CLIPImageProcessor,
+ CLIPProcessor,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ PixArtTransformer2DModel,
+ CausalSparseDiTModel,
+ CausalSparseDiTControlModel,
+ CobraPixArtAlphaPipeline,
+ UniPCMultistepScheduler,
+)
+from cobra_utils.utils import *
+
+from huggingface_hub import snapshot_download
+
+# model_global_path = snapshot_download(repo_id="JunhaoZhuang/Cobra", cache_dir='./Cobra/', repo_type="model")
+model_global_path = './Cobra/models--JunhaoZhuang--Cobra/snapshots/e87145ca2cb71423753905ff10863ebba23ffc19'
+print(model_global_path)
+examples = [
+ [
+ "./examples/shadow/example0/input.png",
+ ["./examples/shadow/example0/reference_image_0.png",
+ "./examples/shadow/example0/reference_image_1.png",
+ "./examples/shadow/example0/reference_image_2.png",
+ "./examples/shadow/example0/reference_image_3.png"],
+ "line + shadow", # style
+ 1, # seed
+ 10, # step
+ 20, # top k
+ ],
+ [
+ "./examples/shadow/example1/input.png",
+ ["./examples/shadow/example1/reference_image_0.png",
+ "./examples/shadow/example1/reference_image_1.png",
+ "./examples/shadow/example1/reference_image_2.png",
+ "./examples/shadow/example1/reference_image_3.png",
+ "./examples/shadow/example1/reference_image_4.png",
+ "./examples/shadow/example1/reference_image_5.png"],
+ "line + shadow", # style
+ 1, # seed
+ 10, # step
+ 20, # top k
+ ],
+ [
+ "./examples/shadow/example2/input.png",
+ ["./examples/shadow/example2/reference_image_0.png"],
+ "line + shadow", # style
+ 1, # seed
+ 10, # step
+ 3, # top k
+ ],
+ [
+ "./examples/line/example2/input.png",
+ ["./examples/line/example2/reference_image_0.png",
+ "./examples/line/example2/reference_image_1.png",
+ "./examples/line/example2/reference_image_2.png",
+ "./examples/line/example2/reference_image_3.png"],
+ "line", # style
+ 1, # seed
+ 10, # step
+ 20, # top k
+ ],
+ [
+ "./examples/line/example0/input.png",
+ ["./examples/line/example0/reference_image_0.png",
+ "./examples/line/example0/reference_image_1.png",
+ "./examples/line/example0/reference_image_2.png"],
+ "line", # style
+ 0, # seed
+ 10, # step
+ 6, # top k
+ ],
+ [
+ "./examples/line/example1/input.png",
+ ["./examples/line/example1/reference_image_0.png",],
+ "line", # style
+ 0, # seed
+ 10, # step
+ 3, # top k
+ ],
+ [
+ "./examples/line/example3/input.png",
+ ["./examples/line/example3/reference_image_0.png",],
+ "line", # style
+ 0, # seed
+ 10, # step
+ 3, # top k
+ ],]
+
+ratio_list = [[800, 800], [768, 896], [704, 928], [672, 960], [640, 1024], [608, 1056], [576, 1088], [576, 1184]]
+ratio_list += [[896, 768], [928, 704], [960, 672], [1024, 640], [1056, 608], [1088, 576], [1184, 576]]
+
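+# Aspect-ratio buckets: get_rate picks the preset resolution whose width/height ratio is
+# closest to the input image. For illustration, a 1000x1500 input (ratio ~0.667) falls into
+# the 672x960 bucket (ratio 0.70).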
+def get_rate(image):
+    # Compute the aspect ratio of the input image
+    input_rate = image.size[0] / image.size[1]
+    # Track the preset resolution whose aspect ratio is closest to the input
+ min_diff = float('inf')
+ best_idx = 0
+
+ for i, ratio in enumerate(ratio_list):
+ ratio_rate = ratio[0] / ratio[1]
+ diff = abs(input_rate - ratio_rate)
+ if diff < min_diff:
+ min_diff = diff
+ best_idx = i
+
+ return ratio_list[best_idx]
+
+
+transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+])
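+# ToTensor maps pixel values to [0, 1]; Normalize(0.5, 0.5) then rescales them to [-1, 1],
+# the input range expected by the VAE encoder.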
+weight_dtype = torch.float16
+
+# line model
+line_model_path = os.path.join(model_global_path, 'LE', 'erika.pth')
+line_model = res_skip()
+line_model.load_state_dict(torch.load(line_model_path))
+line_model.eval()
+line_model.cuda()
+
+
+# image encoder
+image_processor = CLIPImageProcessor()
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.path.join(model_global_path, 'image_encoder')).to('cuda')
+# os.path.join(model_global_path, 'image_encoder')
+
+
+
+# model_sketch = create_model_sketch('default').to('cuda') # create a model given opt.model and other options
+# model_sketch.eval()
+
+
+global pipeline
+global MultiResNetModel
+
+def load_ckpt():
+ global pipeline
+ global MultiResNetModel
+ global causal_dit
+ global controlnet
+ weight_dtype = torch.float16
+
+ block_out_channels = [128, 128, 256, 512, 512]
+ MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
+ MultiResNetModel.load_state_dict(torch.load(os.path.join(model_global_path, 'shadow_GSRP', 'MultiResNetModel.bin'), map_location='cpu'), strict=True)
+ MultiResNetModel.to('cuda', dtype=weight_dtype)
+
+
+ # transformer
+ transform = transforms.Compose([
+        transforms.ToTensor(),  # convert the PIL image to a tensor in [0, 1]
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # rescale to [-1, 1]
+ ])
+ # seed = 43
+ lora_rank = 128
+    pretrained_model_name_or_path = 'PixArt-alpha/PixArt-XL-2-1024-MS'  # base PixArt-alpha weights from the Hub
+
+ transformer = PixArtTransformer2DModel.from_pretrained(
+ pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
+ )
+ pixart_config = get_pixart_config()
+ causal_dit = CausalSparseDiTModel(num_attention_heads=pixart_config.get("num_attention_heads"),
+ attention_head_dim=pixart_config.get("attention_head_dim"),
+ in_channels=pixart_config.get("in_channels"),
+ out_channels=pixart_config.get("out_channels"),
+ num_layers=pixart_config.get("num_layers"),
+ dropout=pixart_config.get("dropout"),
+ norm_num_groups=pixart_config.get("norm_num_groups"),
+ cross_attention_dim=pixart_config.get("cross_attention_dim"),
+ attention_bias=pixart_config.get("attention_bias"),
+ sample_size=pixart_config.get("sample_size"),
+ patch_size=pixart_config.get("patch_size"),
+ activation_fn=pixart_config.get("activation_fn"),
+ num_embeds_ada_norm=pixart_config.get("num_embeds_ada_norm"),
+ upcast_attention=pixart_config.get("upcast_attention"),
+ norm_type=pixart_config.get("norm_type"),
+ norm_elementwise_affine=pixart_config.get("norm_elementwise_affine"),
+ norm_eps=pixart_config.get("norm_eps"),
+ caption_channels=pixart_config.get("caption_channels"),
+ attention_type=pixart_config.get("attention_type"))
+
+ causal_dit = init_causal_dit(causal_dit, transformer)
+ print('loaded causal_dit')
+ controlnet = CausalSparseDiTControlModel(num_attention_heads=pixart_config.get("num_attention_heads"),
+ attention_head_dim=pixart_config.get("attention_head_dim"),
+ in_channels=pixart_config.get("in_channels"),
+ cond_chanels = 9,
+ out_channels=pixart_config.get("out_channels"),
+ num_layers=pixart_config.get("num_layers"),
+ dropout=pixart_config.get("dropout"),
+ norm_num_groups=pixart_config.get("norm_num_groups"),
+ cross_attention_dim=pixart_config.get("cross_attention_dim"),
+ attention_bias=pixart_config.get("attention_bias"),
+ sample_size=pixart_config.get("sample_size"),
+ patch_size=pixart_config.get("patch_size"),
+ activation_fn=pixart_config.get("activation_fn"),
+ num_embeds_ada_norm=pixart_config.get("num_embeds_ada_norm"),
+ upcast_attention=pixart_config.get("upcast_attention"),
+ norm_type=pixart_config.get("norm_type"),
+ norm_elementwise_affine=pixart_config.get("norm_elementwise_affine"),
+ norm_eps=pixart_config.get("norm_eps"),
+ caption_channels=pixart_config.get("caption_channels"),
+ attention_type=pixart_config.get("attention_type")
+ )
+ # controlnet = init_controlnet(controlnet, causal_dit)
+ del transformer
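+    # Wrap the causal DiT's attention and feed-forward projections with rank-128 LoRA adapters;
+    # the fine-tuned LoRA weights are loaded from transformer_lora_pos.bin further below.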
+ transformer_lora_config = LoraConfig(
+ r=lora_rank,
+ lora_alpha=lora_rank,
+ # use_dora=True,
+ init_lora_weights="gaussian",
+ target_modules=["to_k",
+ "to_q",
+ "to_v",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "proj",
+ "linear",
+ "linear_1",
+ "linear_2"],#ff.net.0.proj ff.net.2
+ )
+ causal_dit.add_adapter(transformer_lora_config)
+
+
+ lora_state_dict = torch.load(os.path.join(model_global_path, 'shadow_ckpt', 'transformer_lora_pos.bin'), map_location='cpu')
+ causal_dit.load_state_dict(lora_state_dict, strict=False)
+ controlnet_state_dict = torch.load(os.path.join(model_global_path, 'shadow_ckpt', 'controlnet.bin'), map_location='cpu')
+ controlnet.load_state_dict(controlnet_state_dict, strict=True)
+
+ causal_dit.to('cuda', dtype=weight_dtype)
+ controlnet.to('cuda', dtype=weight_dtype)
+
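+    # Build the Cobra pipeline on top of the base PixArt-alpha weights, swapping in the
+    # LoRA-adapted causal sparse DiT and the sparse-DiT ControlNet prepared above.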
+ pipeline = CobraPixArtAlphaPipeline.from_pretrained(
+ pretrained_model_name_or_path,
+ transformer=causal_dit,
+ controlnet=controlnet,
+ safety_checker=None,
+ revision=None,
+ variant=None,
+ torch_dtype=weight_dtype,
+ )
+
+ pipeline = pipeline.to("cuda")
+
+global cur_style
+cur_style = 'line + shadow'
+def change_ckpt(style):
+ global pipeline
+ global MultiResNetModel
+ global cur_style
+ # global causal_dit
+ # global controlnet
+ weight_dtype = torch.float16
+
+ if style == 'line':
+ MultiResNetModel_path = os.path.join(model_global_path, 'line_GSRP', 'MultiResNetModel.bin')
+ causal_dit_lora_path = os.path.join(model_global_path, 'line_ckpt', 'transformer_lora_pos.bin')
+ controlnet_path = os.path.join(model_global_path, 'line_ckpt', 'controlnet.bin')
+ elif style == 'line + shadow':
+ MultiResNetModel_path = os.path.join(model_global_path, 'shadow_GSRP', 'MultiResNetModel.bin')
+ causal_dit_lora_path = os.path.join(model_global_path, 'shadow_ckpt', 'transformer_lora_pos.bin')
+ controlnet_path = os.path.join(model_global_path, 'shadow_ckpt', 'controlnet.bin')
+ else:
+ raise ValueError("Invalid style: {}".format(style))
+
+ cur_style = style
+
+ MultiResNetModel.load_state_dict(torch.load(MultiResNetModel_path, map_location='cpu'), strict=True)
+ MultiResNetModel.to('cuda', dtype=weight_dtype)
+
+
+ lora_state_dict = torch.load(causal_dit_lora_path, map_location='cpu')
+ pipeline.transformer.load_state_dict(lora_state_dict, strict=False)
+ controlnet_state_dict = torch.load(controlnet_path, map_location='cpu')
+ pipeline.controlnet.load_state_dict(controlnet_state_dict, strict=True)
+
+ pipeline.transformer.to('cuda', dtype=weight_dtype)
+ pipeline.controlnet.to('cuda', dtype=weight_dtype)
+
+ print('loaded {} ckpt'.format(style))
+
+ return style
+
+
+
+load_ckpt()
+
+def fix_random_seeds(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+def process_multi_images(files):
+    return [Image.open(file.name) for file in files]
+
+def extract_lines(image):
+ src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
+
+ rows = int(np.ceil(src.shape[0] / 16)) * 16
+ cols = int(np.ceil(src.shape[1] / 16)) * 16
+
+ patch = np.ones((1, 1, rows, cols), dtype="float32")
+ patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
+
+ tensor = torch.from_numpy(patch).cuda()
+
+ with torch.no_grad():
+ y = line_model(tensor)
+
+ yc = y.cpu().numpy()[0, 0, :, :]
+ yc[yc > 255] = 255
+ yc[yc < 0] = 0
+
+ outimg = yc[0:src.shape[0], 0:src.shape[1]]
+ outimg = outimg.astype(np.uint8)
+ outimg = Image.fromarray(outimg)
+ torch.cuda.empty_cache()
+ return outimg
+
+def extract_line_image(query_image_, resolution):
+ tar_width, tar_height = resolution
+ query_image = query_image_.resize((tar_width, tar_height))
+ # query_image.save('/mnt/workspace/zhuangjunhao/cobra_code/ColorFlow/examples/line/example3/input.png')
+ query_image = query_image.convert('L').convert('RGB')
+ extracted_line = extract_lines(query_image)
+ extracted_line = extracted_line.convert('L').convert('RGB')
+ torch.cuda.empty_cache()
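+    # The all-black canvas returned alongside the line art serves as the initial (empty) hint mask.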
+ return extracted_line, Image.new('RGB', (tar_width, tar_height), 'black')
+
+def extract_sketch_line_image(query_image_, input_style):
+ global cur_style
+ if input_style != cur_style:
+ change_ckpt(input_style)
+
+ resolution = get_rate(query_image_)
+ extracted_line, hint_mask = extract_line_image(query_image_, resolution)
+ extracted_sketch = extracted_line
+ extracted_sketch_line = Image.blend(extracted_sketch, extracted_line, 0.5)
+
+ extracted_sketch_line_ori = copy.deepcopy(extracted_sketch_line)
+
+ extracted_sketch_line_np = np.array(extracted_sketch_line)
+ # extracted_sketch_line_np[extracted_sketch_line_np < 236] = 0
+ # extracted_sketch_line_np[extracted_sketch_line_np >= 236] = 255
+ extracted_sketch_line = Image.fromarray(np.uint8(extracted_sketch_line_np))
+    if input_style == 'line + shadow':
+        print('line + shadow sketch')
+        # Grayscale thresholds (0-255) used to re-inject dark and mid-tone shadow regions
+        black_rate = 74
+        black_value = 18
+        gray_rate = 155
+        up_bound = 145
+        ori_np = np.array(extracted_sketch_line_ori)
+        query_image_np = np.array(query_image_.resize(resolution).convert('L').convert('RGB'))
+        extracted_sketch_line_np = np.array(extracted_sketch_line.convert('L').convert('RGB'))
+        # Force near-black input pixels to black, and clamp bright sketch pixels inside mid-tone regions to gray
+        ori_np[query_image_np <= black_rate] = black_value
+        ori_np[(ori_np > gray_rate) & (query_image_np < up_bound) & (query_image_np > black_rate)] = gray_rate
+        extracted_sketch_line_ori = Image.fromarray(np.uint8(ori_np))
+
+        extracted_sketch_line_np[query_image_np <= black_rate] = black_value
+        extracted_sketch_line_np[(extracted_sketch_line_np > gray_rate) & (query_image_np < up_bound) & (query_image_np > black_rate)] = gray_rate
+        extracted_sketch_line = Image.fromarray(np.uint8(extracted_sketch_line_np))
+
+ return extracted_sketch_line.convert('RGB'), extracted_sketch_line.convert('RGB'), hint_mask, query_image_, extracted_sketch_line_ori.convert('RGB'), resolution
+
+def colorize_image(extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
+ if extracted_line is None:
+ gr.Info("Please preprocess the image first")
+ raise ValueError("Please preprocess the image first")
+ global pipeline
+ global MultiResNetModel
+ reference_images = process_multi_images(reference_images)
+ fix_random_seeds(seed)
+
+ tar_width, tar_height = resolution
+
+ gr.Info("Image retrieval in progress...")
+
+ query_image_bw = extracted_line.resize((tar_width, tar_height))
+ query_image = query_image_bw.convert('RGB')
+
+ query_image_origin = query_image_origin.resize((tar_width, tar_height))
+
+ query_image_vae = extracted_image_ori.resize((int(tar_width*1.5), int(tar_height*1.5)))
+ reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
+ query_patches_pil = process_image_Q_varres(query_image_origin, tar_width, tar_height)
+ reference_patches_pil = []
+ # Save reference_images
+ # save_path = '/mnt/workspace/zhuangjunhao/cobra_code/ColorFlow/examples/line/example3'
+ # os.makedirs(save_path, exist_ok=True)
+ # for idx, ref_image in enumerate(reference_images):
+ # ref_image.save(os.path.join(save_path, f'reference_image_{idx}.png'))
+
+ for reference_image in reference_images:
+ reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
+ with torch.no_grad():
+ clip_img = image_processor(images=query_patches_pil, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
+ query_embeddings = image_encoder(clip_img).image_embeds
+        reference_patches_pil_gray = [rimg.convert('RGB') for rimg in reference_patches_pil]
+ clip_img = image_processor(images=reference_patches_pil_gray, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
+ reference_embeddings = image_encoder(clip_img).image_embeds
+ cosine_similarities = F.cosine_similarity(query_embeddings.unsqueeze(1), reference_embeddings.unsqueeze(0), dim=-1)
+ len_ref = len(reference_patches_pil)
+ # print(cosine_similarities)
+ sorted_indices = torch.argsort(cosine_similarities, descending=True, dim=1).tolist()
+
+ top_k_indices = [cur_sortlist[:top_k] for cur_sortlist in sorted_indices]
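+    # One bucket per query patch: process_image_Q_varres crops the query into a 2x2 grid,
+    # so each of the four query patches keeps its own top-k retrieved reference crops.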
+ available_ref_patches = [[],[],[],[]]
+ for i in range(len(top_k_indices)):
+ for j in range(top_k):
+ available_ref_patches[i].append(reference_patches_pil[top_k_indices[i][j]].resize((tar_width//2, tar_height//2)).convert('RGB'))
+
+ flat_available_ref_patches = [item for sublist in available_ref_patches for item in sublist]
+
+    # Tile flat_available_ref_patches into a roughly square grid for visualization
+ grid_N = int(np.ceil(np.sqrt(len(flat_available_ref_patches))))
+ small_tar_width = tar_width//grid_N
+ small_tar_height = tar_height//grid_N
+ grid_img = Image.new('RGB', (grid_N*small_tar_width, grid_N*small_tar_height), 'black')
+ for i in range(len(flat_available_ref_patches)):
+ grid_img.paste(flat_available_ref_patches[i].resize((small_tar_width, small_tar_height)), (i%grid_N*small_tar_width, int(i/grid_N)*small_tar_height))
+
+    # Label the grid image with the text "Reference Images"
+ draw = ImageDraw.Draw(grid_img)
+ draw.text((0, 0), "Reference Images", fill='red', font_size=50)
+
+ gr.Info("Model inference in progress...")
+ generator = torch.Generator(device='cuda').manual_seed(seed)
+ hint_mask = hint_mask.resize((tar_width//8, tar_height//8)).convert('RGB')
+ hint_color = hint_color.convert('RGB')
+
+ colorized_image = pipeline(
+ cond_input=query_image_bw.convert('RGB'),
+ cond_refs=available_ref_patches,
+ hint_mask=hint_mask,
+ hint_color=hint_color,
+ num_inference_steps=num_inference_steps,
+ generator = generator,
+ )[0][0]
+ gr.Info("Post-processing image...")
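+    # Guided super-resolution: encode the upscaled colorized result and the original sketch at
+    # 1.5x resolution, fuse their VAE encoder hidden states with MultiResNetModel, and pass the
+    # fused features to the customized VAE decoder to produce the high-resolution output.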
+ with torch.no_grad():
+ up_img = colorized_image.resize(query_image_vae.size)
+ test_low_color = transform(up_img).unsqueeze(0).to('cuda', dtype=weight_dtype)
+ query_image_vae_ = transform(query_image_vae).unsqueeze(0).to('cuda', dtype=weight_dtype)
+
+ h_color, hidden_list_color = pipeline.vae._encode(test_low_color,return_dict = False, hidden_flag = True)
+ h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae_, return_dict = False, hidden_flag = True)
+
+ hidden_list_double = [torch.cat((hidden_list_color[hidden_idx], hidden_list_bw[hidden_idx]), dim = 1) for hidden_idx in range(len(hidden_list_color))]
+
+
+ hidden_list = MultiResNetModel(hidden_list_double)
+ output = pipeline.vae._decode(h_color.sample(),return_dict = False, hidden_list = hidden_list)[0]
+
+ output[output > 1] = 1
+ output[output < -1] = -1
+ high_res_image = Image.fromarray(((output[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)).convert("RGB")
+ gr.Info("Colorization complete!")
+ torch.cuda.empty_cache()
+
+ output_gallery = [high_res_image, query_image_bw, hint_mask, hint_color, grid_img]
+ return output_gallery
+
+
+# Function to get color value from reference image
+def get_color_value(reference_image, evt: gr.SelectData):
+ if reference_image is None:
+ return "Please upload a reference image first."
+ x, y = evt.index
+ color_value = reference_image[y, x]
+ return f"Get Color value: {color_value}", color_value
+
+# Function to draw a square on the line drawing image
+def draw_square(line_drawing_image_pil, hint_mask, color_value, evt: gr.SelectData):
+    # Validate inputs before converting them to arrays
+    if line_drawing_image_pil is None:
+        return "Please upload a line drawing image first."
+    if color_value is None:
+        return "Please pick a color from the reference image first."
+    line_drawing_image = np.array(line_drawing_image_pil)
+    # line_drawing_image = np.array(Image.new('RGB', line_drawing_image_pil.size, 'black'))
+    hint_mask = np.array(hint_mask)
+    x, y = evt.index
+ # Calculate square boundaries
+ start_x = max(0, x - 8)
+ start_y = max(0, y - 8)
+ end_x = min(line_drawing_image.shape[1], x + 8)
+ end_y = min(line_drawing_image.shape[0], y + 8)
+ # Draw the square
+ line_drawing_image[start_y:end_y, start_x:end_x] = color_value
+ line_drawing_image_pil = Image.fromarray(np.uint8(line_drawing_image))
+ hint_mask[start_y:end_y, start_x:end_x] = 255
+ hint_mask_pil = Image.fromarray(np.uint8(hint_mask))
+ return line_drawing_image_pil, hint_mask_pil
+
+
+with gr.Blocks() as demo:
+ gr.HTML(
+ """
+
+
+🎨 Cobra:
+
+Efficient Line Art COlorization with BRoAder References
+
+ Project Page |
+ ArXiv Preprint |
+ GitHub Repository
+
+
+ NOTE: Each time you switch the input style, the corresponding model will be reloaded, which may take some time. Please be patient.
+
+
+ Welcome to the demo of Cobra. Follow the steps below to explore the capabilities of our model:
+
+
+
+
+ Choose your input style: either line + shadow or line only.
+ Upload your image: Click the 'Upload' button to select the image you want to colorize.
+ Preprocess the image: Click the 'Preprocess' button to extract the line art from your image.
+ (Optional) Obtain color values and add color hints: Upload an image to the left area and click to get color values; then, add color hints to the line art on the right.
+ Upload reference images: Upload several reference images to help guide the colorization process.
+ (Optional) Set inference parameters: Adjust the inference settings as needed.
+ Run: Click the Colorize button to start the process.
+
+
+ ⏱️ ZeroGPU Time Limit: Hugging Face ZeroGPU has an inference time limit of 180 seconds. You may need to log in with a free account to use this demo. A large number of sampling steps may cause a timeout (GPU abort). In that case, please consider logging in with a Pro account or running the demo on your local machine.
+
+
+
+
+ 注意:每次切换输入样式时,相应的模型将被重新加载,可能需要一些时间。请耐心等待。
+
+
+ 欢迎使用 Cobra 演示。请按照以下步骤探索我们模型的能力:
+
+
+
+
+ 选择输入样式:线条+阴影或仅线条。
+ 上传您的图像:点击“上传”按钮选择您想要上色的图像。
+ 预处理图像:点击“预处理”按钮从您的图像中提取线稿。
+ (可选)获取颜色值并添加颜色提示:上传一张图像到左侧区域,点击获取颜色值;然后,为右侧的线稿添加颜色提示。
+ 上传参考图像:上传多个参考图像以帮助引导上色过程。
+ (可选)设置推理参数:根据需要调整推理设置。
+ 运行:点击 上色 按钮开始处理。
+
+
+ ⏱️ ZeroGPU时间限制 :Hugging Face ZeroGPU 的推理时间限制为 180 秒。您可能需要使用免费帐户登录以使用此演示。大采样步骤可能会导致超时(GPU 中止)。在这种情况下,请考虑使用专业帐户登录或在本地计算机上运行。
+
+
+ """
+)
+ # extracted_line = gr.State()
+ # example_loading = gr.State(value=None)
+ hint_mask = gr.State()
+ hint_color = gr.State()
+ query_image_origin = gr.State()
+ resolution = gr.State()
+ extracted_image_ori = gr.State()
+ style = gr.State()
+ # updated_mask = gr.State()
+ # model_name = gr.Textbox(label="Model Name", value=None)
+
+ # style = gr.Dropdown(label="Model Name", choices=["line + shadow","line"], value="line + shadow")
+
+ with gr.Column():
+ gr.Markdown("Load Model ")
+ with gr.Row():
+ model_name = gr.Textbox(label="Model Name", value=None)
+ with gr.Column():
+ style = gr.Dropdown(label="Model List", choices=["line + shadow","line"], value="line + shadow")
+ change_ckpt_button = gr.Button("Load Model")
+ change_ckpt_button.click(change_ckpt, inputs=[style], outputs=[model_name])
+ # model_name = gr.Textbox(label="Model Name", value=None)
+
+    # Section header: line drawing extraction
+ gr.Markdown("Line Drawing Extraction ")
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(type="pil", label="Image to Colorize")
+ # resolution = gr.Radio(["800x800", "640x1024", "1024x640"], label="Select Resolution(Width*Height)", value="640x1024")
+ extract_button = gr.Button("Preprocess (Decolorize)")
+ extracted_image = gr.Image(type="pil", label="Decolorized Result")
+
+
+ gr.Markdown("Color Selection 🎨 (Left) and Hint Placement 💡 (Right) - Click with Mouse 🖱️ ")
+
+ with gr.Row():
+ with gr.Column():
+ get_color_img = gr.Image(label="Upload an image to extract colors", type="numpy")
+ color_value_output = gr.Textbox(label="Color Value")
+ color_value_state = gr.State()
+ get_color_img.select(
+ get_color_value,
+ [get_color_img],
+ [color_value_output, color_value_state]
+ )
+ with gr.Column():
+ hint_color = gr.Image(label="Line Drawing Image", type="pil")
+ # updated_image = gr.Image(label="Updated Image", type="pil")
+ hint_color.select(
+ draw_square,
+ [hint_color, hint_mask, color_value_state],
+ [hint_color, hint_mask]
+ )
+
+ gr.Markdown("Retrieval and Colorization ")
+ with gr.Row():
+ reference_images = gr.Files(label="Reference Images (Upload multiple)", file_count="multiple")
+ with gr.Column():
+ output_gallery = gr.Gallery(label="Colorization Results", type="pil")
+ seed = gr.Slider(label="Random Seed", minimum=0, maximum=100000, value=0, step=1)
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, value=10, step=1)
+ colorize_button = gr.Button("Colorize")
+ top_k = gr.Slider(label="Top K (Total Reference Images: 4K) ", minimum=1, maximum=50, value=3, step=1)
+
+
+ extract_button.click(
+ extract_sketch_line_image,
+ inputs=[input_image, model_name],
+ outputs=[extracted_image,
+ hint_color,
+ hint_mask,
+ query_image_origin,
+ extracted_image_ori,
+ resolution
+ ]
+ )
+ colorize_button.click(
+ colorize_image,
+ inputs=[extracted_image, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask, hint_color, query_image_origin, extracted_image_ori],
+ outputs=output_gallery
+ )
+ with gr.Column():
+ gr.Markdown("### Quick Examples")
+ gr.Examples(
+ examples=examples,
+ inputs=[input_image, reference_images, model_name, seed, num_inference_steps, top_k],
+ label="Examples",
+ examples_per_page=8,
+ )
+
+
+demo.launch(server_name="0.0.0.0", server_port=52218)
\ No newline at end of file
diff --git a/cobra_utils/.DS_Store b/cobra_utils/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..a3d9c72411f907964c8c222c3674b8152b145eb8
Binary files /dev/null and b/cobra_utils/.DS_Store differ
diff --git a/cobra_utils/__pycache__/utils.cpython-311.pyc b/cobra_utils/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf60fa0942ca0a2b337d2297618c19caaa4e101b
Binary files /dev/null and b/cobra_utils/__pycache__/utils.cpython-311.pyc differ
diff --git a/cobra_utils/utils.py b/cobra_utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d650d2c1bc659c1587136e387d2e9b918c66910
--- /dev/null
+++ b/cobra_utils/utils.py
@@ -0,0 +1,603 @@
+import os
+import random
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset
+from PIL import Image
+from torchvision import transforms
+from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+import matplotlib.pyplot as plt
+import cv2
+import torch.nn.functional as F
+
+
+
+class _bn_relu_conv(nn.Module):
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+ super(_bn_relu_conv, self).__init__()
+ self.model = nn.Sequential(
+ nn.BatchNorm2d(in_filters, eps=1e-3),
+ nn.LeakyReLU(0.2),
+ nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros')
+ )
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class _u_bn_relu_conv(nn.Module):
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+ super(_u_bn_relu_conv, self).__init__()
+ self.model = nn.Sequential(
+ nn.BatchNorm2d(in_filters, eps=1e-3),
+ nn.LeakyReLU(0.2),
+ nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)),
+ nn.Upsample(scale_factor=2, mode='nearest')
+ )
+
+ def forward(self, x):
+ return self.model(x)
+
+
+
+class _shortcut(nn.Module):
+ def __init__(self, in_filters, nb_filters, subsample=1):
+ super(_shortcut, self).__init__()
+ self.process = False
+ self.model = None
+ if in_filters != nb_filters or subsample != 1:
+ self.process = True
+ self.model = nn.Sequential(
+ nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
+ )
+
+ def forward(self, x, y):
+ #print(x.size(), y.size(), self.process)
+ if self.process:
+ y0 = self.model(x)
+ #print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
+ return y0 + y
+ else:
+ #print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
+ return x + y
+
+class _u_shortcut(nn.Module):
+ def __init__(self, in_filters, nb_filters, subsample):
+ super(_u_shortcut, self).__init__()
+ self.process = False
+ self.model = None
+ if in_filters != nb_filters:
+ self.process = True
+ self.model = nn.Sequential(
+ nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
+ nn.Upsample(scale_factor=2, mode='nearest')
+ )
+
+ def forward(self, x, y):
+ if self.process:
+ return self.model(x) + y
+ else:
+ return x + y
+
+
+class basic_block(nn.Module):
+ def __init__(self, in_filters, nb_filters, init_subsample=1):
+ super(basic_block, self).__init__()
+ self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
+ self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
+ self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)
+
+ def forward(self, x):
+ x1 = self.conv1(x)
+ x2 = self.residual(x1)
+ return self.shortcut(x, x2)
+
+class _u_basic_block(nn.Module):
+ def __init__(self, in_filters, nb_filters, init_subsample=1):
+ super(_u_basic_block, self).__init__()
+ self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
+ self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
+ self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)
+
+ def forward(self, x):
+ y = self.residual(self.conv1(x))
+ return self.shortcut(x, y)
+
+
+class _residual_block(nn.Module):
+ def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
+ super(_residual_block, self).__init__()
+ layers = []
+ for i in range(repetitions):
+ init_subsample = 1
+ if i == repetitions - 1 and not is_first_layer:
+ init_subsample = 2
+ if i == 0:
+ l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample)
+ else:
+ l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample)
+ layers.append(l)
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class _upsampling_residual_block(nn.Module):
+ def __init__(self, in_filters, nb_filters, repetitions):
+ super(_upsampling_residual_block, self).__init__()
+ layers = []
+ for i in range(repetitions):
+ l = None
+ if i == 0:
+ l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)#(input)
+ else:
+ l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)#(input)
+ layers.append(l)
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.model(x)
+
+
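+# Line-extraction network (loaded as the "erika" line model in app.py): a residual encoder-decoder
+# where downsampling blocks are mirrored by upsampling blocks and fused through shortcuts,
+# mapping a grayscale image to a single-channel line drawing.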
+class res_skip(nn.Module):
+
+ def __init__(self):
+ super(res_skip, self).__init__()
+ self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)#(input)
+ self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)#(block0)
+ self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)#(block1)
+ self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)#(block2)
+ self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)#(block3)
+
+ self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)#(block4)
+ self.res1 = _shortcut(in_filters=192, nb_filters=192)#(block3, block5, subsample=(1,1))
+
+ self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)#(res1)
+ self.res2 = _shortcut(in_filters=96, nb_filters=96)#(block2, block6, subsample=(1,1))
+
+ self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)#(res2)
+ self.res3 = _shortcut(in_filters=48, nb_filters=48)#(block1, block7, subsample=(1,1))
+
+ self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)#(res3)
+ self.res4 = _shortcut(in_filters=24, nb_filters=24)#(block0,block8, subsample=(1,1))
+
+ self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)#(res4)
+ self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)#(block7)
+
+ def forward(self, x):
+ x0 = self.block0(x)
+ x1 = self.block1(x0)
+ x2 = self.block2(x1)
+ x3 = self.block3(x2)
+ x4 = self.block4(x3)
+
+ x5 = self.block5(x4)
+ res1 = self.res1(x3, x5)
+
+ x6 = self.block6(res1)
+ res2 = self.res2(x2, x6)
+
+ x7 = self.block7(res2)
+ res3 = self.res3(x1, x7)
+
+ x8 = self.block8(res3)
+ res4 = self.res4(x0, x8)
+
+ x9 = self.block9(res4)
+ y = self.conv15(x9)
+
+ return y
+
+class MyDataset(Dataset):
+ def __init__(self, image_paths, transform=None):
+ self.image_paths = image_paths
+ self.transform = transform
+
+ def get_class_label(self, image_name):
+ # your method here
+ head, tail = os.path.split(image_name)
+ #print(tail)
+ return tail
+
+ def __getitem__(self, index):
+ image_path = self.image_paths[index]
+ x = Image.open(image_path)
+ y = self.get_class_label(image_path.split('/')[-1])
+ if self.transform is not None:
+ x = self.transform(x)
+ return x, y
+
+ def __len__(self):
+ return len(self.image_paths)
+
+def loadImages(folder):
+ imgs = []
+ matches = []
+
+ for filename in os.listdir(folder):
+ file_path = os.path.join(folder, filename)
+ if os.path.isfile(file_path):
+ matches.append(file_path)
+
+ return matches
+
+
+def crop_center_square(image):
+
+ width, height = image.size
+
+
+ side_length = min(width, height)
+
+ left = (width - side_length) // 2
+ top = (height - side_length) // 2
+ right = left + side_length
+ bottom = top + side_length
+
+ cropped_image = image.crop((left, top, right, bottom))
+
+ return cropped_image
+
+def crop_image(image, crop_size, stride):
+
+ width, height = image.size
+ crop_width, crop_height = crop_size
+ cropped_images = []
+
+ for j in range(0, height - crop_height + 1, stride):
+ for i in range(0, width - crop_width + 1, stride):
+ crop_box = (i, j, i + crop_width, j + crop_height)
+ cropped_image = image.crop(crop_box)
+ cropped_images.append(cropped_image)
+
+ return cropped_images
+
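+# For a 512x512 reference, a 384x384 crop with stride 128 yields a 2x2 grid, so process_image_ref
+# returns 5 views: the full resized image plus 4 overlapping crops.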
+def process_image_ref(image):
+
+ resized_image_512 = image.resize((512, 512))
+
+
+ image_list = [resized_image_512]
+
+
+ crop_size_384 = (384, 384)
+ stride_384 = 128
+ image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))
+
+
+ return image_list
+
+
+def process_image_Q(image):
+
+    resized_image_512 = image.resize((512, 512)).convert("RGB")
+
+ image_list = []
+
+ crop_size_384 = (384, 384)
+ stride_384 = 128
+ image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))
+
+ return image_list
+
+def process_image(image, target_width=512, target_height = 512):
+ img_width, img_height = image.size
+ img_ratio = img_width / img_height
+
+ target_ratio = target_width / target_height
+
+ ratio_error = abs(img_ratio - target_ratio) / target_ratio
+
+ if ratio_error < 0.15:
+ resized_image = image.resize((target_width, target_height), Image.BICUBIC)
+ else:
+ if img_ratio > target_ratio:
+ new_width = int(img_height * target_ratio)
+ left = int((0 + img_width - new_width)/2)
+ top = 0
+ right = left + new_width
+ bottom = img_height
+ else:
+ new_height = int(img_width / target_ratio)
+ left = 0
+ top = int((0 + img_height - new_height)/2)
+ right = img_width
+ bottom = top + new_height
+
+ cropped_image = image.crop((left, top, right, bottom))
+ resized_image = cropped_image.resize((target_width, target_height), Image.BICUBIC)
+
+ return resized_image.convert('RGB')
+
+def crop_image_varres(image, crop_size, h_stride, w_stride):
+ width, height = image.size
+ crop_width, crop_height = crop_size
+ cropped_images = []
+
+ for j in range(0, height - crop_height + 1, h_stride):
+ for i in range(0, width - crop_width + 1, w_stride):
+ crop_box = (i, j, i + crop_width, j + crop_height)
+ cropped_image = image.crop(crop_box)
+ cropped_images.append(cropped_image)
+
+ return cropped_images
+
+def process_image_ref_varres(image, target_width=512, target_height = 512):
+ resized_image_512 = image.resize((target_width, target_height))
+
+ image_list = [resized_image_512]
+
+ crop_size_384 = (target_width//4*3, target_height//4*3)
+ w_stride_384 = target_width//4
+ h_stride_384 = target_height//4
+ image_list.extend(crop_image_varres(resized_image_512, crop_size_384, h_stride = h_stride_384, w_stride = w_stride_384))
+
+ return image_list
+
+
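+# Query counterpart of process_image_ref_varres: only the 2x2 grid of 3/4-size crops is returned
+# (four patches, no full view), matching the four retrieval buckets built in app.py.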
+def process_image_Q_varres(image, target_width=512, target_height = 512):
+
+    resized_image_512 = image.resize((target_width, target_height)).convert("RGB")
+
+ image_list = []
+
+ crop_size_384 = (target_width//4*3, target_height//4*3)
+ w_stride_384 = target_width//4
+ h_stride_384 = target_height//4
+ image_list.extend(crop_image_varres(resized_image_512, crop_size_384, h_stride = h_stride_384, w_stride = w_stride_384))
+
+
+ return image_list
+
+
+class ResNetBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, stride=1):
+ super(ResNetBlock, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_channels != out_channels:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_channels)
+ )
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)  # add the residual shortcut
+ out = F.relu(out)
+ return out
+
+class TwoLayerResNet(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(TwoLayerResNet, self).__init__()
+ self.block1 = ResNetBlock(in_channels, out_channels)
+ self.block2 = ResNetBlock(out_channels, out_channels)
+ self.block3 = ResNetBlock(out_channels, out_channels)
+ self.block4 = ResNetBlock(out_channels, out_channels)
+
+ def forward(self, x):
+ x = self.block1(x)
+ x = self.block2(x)
+ x = self.block3(x)
+ x = self.block4(x)
+ return x
+
+
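+# Fuses paired VAE encoder features: each input tensor concatenates the colorized and line-art
+# hidden states (hence in_channels = channels_list[idx] * 2), and the output width follows
+# channels_list shifted by two levels, which is what the customized VAE decoder consumes as hidden_list.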
+class MultiHiddenResNetModel(nn.Module):
+ def __init__(self, channels_list, num_tensors):
+ super(MultiHiddenResNetModel, self).__init__()
+ self.two_layer_resnets = nn.ModuleList([TwoLayerResNet(channels_list[idx]*2, channels_list[min(len(channels_list)-1,idx+2)]) for idx in range(num_tensors)])
+
+ def forward(self, tensor_list):
+ processed_list = []
+ for i, tensor in enumerate(tensor_list):
+ tensor = self.two_layer_resnets[i](tensor)
+ processed_list.append(tensor)
+
+ return processed_list
+
+
+def calculate_target_size(h, w):
+    # Round the spatial size down to a multiple of 8 (minimum of 8, enforced below)
+    target_h = (h // 8) * 8
+    target_w = (w // 8) * 8
+
+ if target_h == 0:
+ target_h = 8
+ if target_w == 0:
+ target_w = 8
+
+ return target_h, target_w
+
+
+def downsample_tensor(tensor):
+ b, c, h, w = tensor.shape
+
+ target_h, target_w = calculate_target_size(h, w)
+
+ downsampled_tensor = F.interpolate(tensor, size=(target_h, target_w), mode='bilinear', align_corners=False)
+
+ return downsampled_tensor
+
+
+
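+# Configuration mirroring the published PixArt-XL-2-1024-MS Transformer2DModel config, so the
+# causal sparse DiT and its ControlNet share the base model's architecture and dimensions.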
+def get_pixart_config():
+ pixart_config = {
+ "_class_name": "Transformer2DModel",
+ "_diffusers_version": "0.22.0.dev0",
+ "activation_fn": "gelu-approximate",
+ "attention_bias": True,
+ "attention_head_dim": 72,
+ "attention_type": "default",
+ "caption_channels": 4096,
+ "cross_attention_dim": 1152,
+ "double_self_attention": False,
+ "dropout": 0.0,
+ "in_channels": 4,
+ # "interpolation_scale": 2,
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-06,
+ "norm_num_groups": 32,
+ "norm_type": "ada_norm_single",
+ "num_attention_heads": 16,
+ "num_embeds_ada_norm": 1000,
+ "num_layers": 28,
+ "num_vector_embeds": None,
+ "only_cross_attention": False,
+ "out_channels": 8,
+ "patch_size": 2,
+ "sample_size": 128,
+ "upcast_attention": False,
+ # "use_additional_conditions": False,
+ "use_linear_projection": False
+ }
+ return pixart_config
+
+
+
+class DoubleConv(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.double_conv = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(),
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ return self.double_conv(x)
+
+
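+# A standard 4-level U-Net taking a 6-channel input (two RGB images concatenated) and predicting
+# a 3-channel RGB output; it is defined here but not used by app.py.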
+class UNet(nn.Module):
+ def __init__(self):
+ super().__init__()
+ # left
+ self.left_conv_1 = DoubleConv(6, 64)
+ self.down_1 = nn.MaxPool2d(2, 2)
+
+ self.left_conv_2 = DoubleConv(64, 128)
+ self.down_2 = nn.MaxPool2d(2, 2)
+
+ self.left_conv_3 = DoubleConv(128, 256)
+ self.down_3 = nn.MaxPool2d(2, 2)
+
+ self.left_conv_4 = DoubleConv(256, 512)
+ self.down_4 = nn.MaxPool2d(2, 2)
+
+ # center
+ self.center_conv = DoubleConv(512, 1024)
+
+ # right
+ self.up_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
+ self.right_conv_1 = DoubleConv(1024, 512)
+
+ self.up_2 = nn.ConvTranspose2d(512, 256, 2, 2)
+ self.right_conv_2 = DoubleConv(512, 256)
+
+ self.up_3 = nn.ConvTranspose2d(256, 128, 2, 2)
+ self.right_conv_3 = DoubleConv(256, 128)
+
+ self.up_4 = nn.ConvTranspose2d(128, 64, 2, 2)
+ self.right_conv_4 = DoubleConv(128, 64)
+
+ # output
+ self.output = nn.Conv2d(64, 3, 1, 1, 0)
+
+ def forward(self, x):
+ # left
+ x1 = self.left_conv_1(x)
+ x1_down = self.down_1(x1)
+
+ x2 = self.left_conv_2(x1_down)
+ x2_down = self.down_2(x2)
+
+ x3 = self.left_conv_3(x2_down)
+ x3_down = self.down_3(x3)
+
+ x4 = self.left_conv_4(x3_down)
+ x4_down = self.down_4(x4)
+
+ # center
+ x5 = self.center_conv(x4_down)
+
+ # right
+ x6_up = self.up_1(x5)
+ temp = torch.cat((x6_up, x4), dim=1)
+ x6 = self.right_conv_1(temp)
+
+ x7_up = self.up_2(x6)
+ temp = torch.cat((x7_up, x3), dim=1)
+ x7 = self.right_conv_2(temp)
+
+ x8_up = self.up_3(x7)
+ temp = torch.cat((x8_up, x2), dim=1)
+ x8 = self.right_conv_3(temp)
+
+ x9_up = self.up_4(x8)
+ temp = torch.cat((x9_up, x1), dim=1)
+ x9 = self.right_conv_4(temp)
+
+ # output
+ output = self.output(x9)
+
+ return output
+
+
+
+from copy import deepcopy
+
+def init_causal_dit(model, base_model):
+ temp_ckpt = deepcopy(base_model)
+ checkpoint = temp_ckpt.state_dict()
+ # checkpoint['pos_embed_1d.weight'] = torch.zeros(3, model.config.num_attention_heads * model.config.attention_head_dim, device=model.pos_embed_1d.weight.device, dtype = model.pos_embed_1d.weight.dtype)
+ model.load_state_dict(checkpoint, strict=True)
+ del temp_ckpt
+ return model
+
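+# Initializes the ControlNet from the base DiT weights. The patch-embedding convolution is widened
+# for the extra conditioning channels, which are zero-initialized (ControlNet-style) so they
+# contribute nothing until fine-tuned.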
+def init_controlnet(model, base_model):
+ temp_ckpt = deepcopy(base_model)
+ checkpoint = temp_ckpt.state_dict()
+ checkpoint_weight = checkpoint['pos_embed.proj.weight']
+ new_weight = torch.zeros(model.pos_embed.proj.weight.shape, device=model.pos_embed.proj.weight.device, dtype = model.pos_embed.proj.weight.dtype)
+ print('model.pos_embed.proj.weight.shape',model.pos_embed.proj.weight.shape)
+ new_weight[:, :4] = checkpoint_weight
+ checkpoint['pos_embed.proj.weight'] = new_weight
+ print('new_weight', new_weight.dtype)
+ model.load_state_dict(checkpoint, strict=False)
+ del temp_ckpt
+ return model
\ No newline at end of file
diff --git a/diffusers/.DS_Store b/diffusers/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..55e8b6050642f6a9af5511ace40cabc8b9db74fc
Binary files /dev/null and b/diffusers/.DS_Store differ
diff --git a/diffusers/.github/ISSUE_TEMPLATE/bug-report.yml b/diffusers/.github/ISSUE_TEMPLATE/bug-report.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a0517725284e6062287e08dad3495d94b0bd0481
--- /dev/null
+++ b/diffusers/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -0,0 +1,110 @@
+name: "\U0001F41B Bug Report"
+description: Report a bug on Diffusers
+labels: [ "bug" ]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ Thanks a lot for taking the time to file this issue 🤗.
+ Issues do not only help to improve the library, but also publicly document common problems, questions, workflows for the whole community!
+ Thus, issues are of the same importance as pull requests when contributing to this library ❤️.
+ In order to make your issue as **useful for the community as possible**, let's try to stick to some simple guidelines:
+ - 1. Please try to be as precise and concise as possible.
+ *Give your issue a fitting title. Assume that someone with very limited knowledge of Diffusers can understand your issue. Add links to the source code, documentation, other issues, pull requests, etc.*
+ - 2. If your issue is about something not working, **always** provide a reproducible code snippet. The reader should be able to reproduce your issue by **only copy-pasting your code snippet into a Python shell**.
+ *The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
+ - 3. Add the **minimum** amount of code / context that is needed to understand, reproduce your issue.
+ *Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
+ - 4. For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.
+ - type: markdown
+ attributes:
+ value: |
+ For more in-detail information on how to write good issues you can have a look [here](https://huggingface.co/course/chapter8/5?fw=pt).
+ - type: textarea
+ id: bug-description
+ attributes:
+ label: Describe the bug
+ description: A clear and concise description of what the bug is. If you intend to submit a pull request for this issue, tell us in the description. Thanks!
+ placeholder: Bug description
+ validations:
+ required: true
+ - type: textarea
+ id: reproduction
+ attributes:
+ label: Reproduction
+ description: Please provide a minimal reproducible code which we can copy/paste and reproduce the issue.
+ placeholder: Reproduction
+ validations:
+ required: true
+ - type: textarea
+ id: logs
+ attributes:
+ label: Logs
+ description: "Please include the Python logs if you can."
+ render: shell
+ - type: textarea
+ id: system-info
+ attributes:
+ label: System Info
+ description: Please share your system info with us. You can run the command `diffusers-cli env` and copy-paste its output below.
+ placeholder: Diffusers version, platform, Python version, ...
+ validations:
+ required: true
+ - type: textarea
+ id: who-can-help
+ attributes:
+ label: Who can help?
+ description: |
+ Your issue will be replied to more quickly if you can figure out the right person to tag with @.
+ If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
+
+ All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and
+ a core maintainer will ping the right person.
+
+ Please tag a maximum of 2 people.
+
+ Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...): @sayakpaul @DN6
+
+ Questions on pipelines:
+ - Stable Diffusion @yiyixuxu @asomoza
+ - Stable Diffusion XL @yiyixuxu @sayakpaul @DN6
+ - Stable Diffusion 3: @yiyixuxu @sayakpaul @DN6 @asomoza
+ - Kandinsky @yiyixuxu
+ - ControlNet @sayakpaul @yiyixuxu @DN6
+ - T2I Adapter @sayakpaul @yiyixuxu @DN6
+ - IF @DN6
+ - Text-to-Video / Video-to-Video @DN6 @a-r-r-o-w
+ - Wuerstchen @DN6
+ - Other: @yiyixuxu @DN6
+ - Improving generation quality: @asomoza
+
+ Questions on models:
+ - UNet @DN6 @yiyixuxu @sayakpaul
+ - VAE @sayakpaul @DN6 @yiyixuxu
+ - Transformers/Attention @DN6 @yiyixuxu @sayakpaul
+
+ Questions on single file checkpoints: @DN6
+
+ Questions on Schedulers: @yiyixuxu
+
+ Questions on LoRA: @sayakpaul
+
+ Questions on Textual Inversion: @sayakpaul
+
+ Questions on Training:
+ - DreamBooth @sayakpaul
+ - Text-to-Image Fine-tuning @sayakpaul
+ - Textual Inversion @sayakpaul
+ - ControlNet @sayakpaul
+
+ Questions on Tests: @DN6 @sayakpaul @yiyixuxu
+
+ Questions on Documentation: @stevhliu
+
+ Questions on JAX- and MPS-related things: @pcuenca
+
+ Questions on audio pipelines: @sanchit-gandhi
+
+
+
+ placeholder: "@Username ..."
diff --git a/diffusers/.github/ISSUE_TEMPLATE/config.yml b/diffusers/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e81992fe3c69b65f58f627252ffa6569d1cd67e2
--- /dev/null
+++ b/diffusers/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,4 @@
+contact_links:
+ - name: Questions / Discussions
+ url: https://github.com/huggingface/diffusers/discussions
+ about: General usage questions and community discussions
diff --git a/diffusers/.github/ISSUE_TEMPLATE/feature_request.md b/diffusers/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000000000000000000000000000000000000..42f93232c1de7c73dcd90cdb6b0733bbb4461508
--- /dev/null
+++ b/diffusers/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,20 @@
+---
+name: "\U0001F680 Feature Request"
+about: Suggest an idea for this project
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I'm always frustrated when [...].
+
+**Describe the solution you'd like.**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered.**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context.**
+Add any other context or screenshots about the feature request here.
diff --git a/diffusers/.github/ISSUE_TEMPLATE/feedback.md b/diffusers/.github/ISSUE_TEMPLATE/feedback.md
new file mode 100644
index 0000000000000000000000000000000000000000..25808b6575a405694f64dbf1b5a0ece8e0fcd2e2
--- /dev/null
+++ b/diffusers/.github/ISSUE_TEMPLATE/feedback.md
@@ -0,0 +1,12 @@
+---
+name: "💬 Feedback about API Design"
+about: Give feedback about the current API design
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**What API design would you like to have changed or added to the library? Why?**
+
+**What use case would this enable or better enable? Can you give us a code example?**
diff --git a/diffusers/.github/ISSUE_TEMPLATE/new-model-addition.yml b/diffusers/.github/ISSUE_TEMPLATE/new-model-addition.yml
new file mode 100644
index 0000000000000000000000000000000000000000..432e287dd3348965466a696ee5e01a187f179ee5
--- /dev/null
+++ b/diffusers/.github/ISSUE_TEMPLATE/new-model-addition.yml
@@ -0,0 +1,31 @@
+name: "\U0001F31F New Model/Pipeline/Scheduler Addition"
+description: Submit a proposal/request to implement a new diffusion model/pipeline/scheduler
+labels: [ "New model/pipeline/scheduler" ]
+
+body:
+ - type: textarea
+ id: description-request
+ validations:
+ required: true
+ attributes:
+ label: Model/Pipeline/Scheduler description
+ description: |
+ Put any and all important information relative to the model/pipeline/scheduler
+
+ - type: checkboxes
+ id: information-tasks
+ attributes:
+ label: Open source status
+ description: |
+ Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `diffusers`.
+ options:
+ - label: "The model implementation is available."
+ - label: "The model weights are available (Only relevant if addition is not a scheduler)."
+
+ - type: textarea
+ id: additional-info
+ attributes:
+ label: Provide useful links for the implementation
+ description: |
+ Please provide information regarding the implementation, the weights, and the authors.
+ Please mention the authors by @gh-username if you're aware of their usernames.
diff --git a/diffusers/.github/ISSUE_TEMPLATE/translate.md b/diffusers/.github/ISSUE_TEMPLATE/translate.md
new file mode 100644
index 0000000000000000000000000000000000000000..3471ec9640d727e7cdf223852d2012834660e88a
--- /dev/null
+++ b/diffusers/.github/ISSUE_TEMPLATE/translate.md
@@ -0,0 +1,29 @@
+---
+name: 🌐 Translating a New Language?
+about: Start a new translation effort in your language
+title: '[<languageCode>] Translating docs to <languageName>'
+labels: WIP
+assignees: ''
+
+---
+
+
+
+Hi!
+
+Let's bring the documentation to all the <languageName>-speaking community 🌐.
+
+Who would want to translate? Please follow the 🤗 [TRANSLATING guide](https://github.com/huggingface/diffusers/blob/main/docs/TRANSLATING.md). Here is a list of the files ready for translation. Let us know in this issue if you'd like to translate any, and we'll add your name to the list.
+
+Some notes:
+
+* Please translate using an informal tone (imagine you are talking with a friend about Diffusers 🤗).
+* Please translate in a gender-neutral way.
+* Add your translations to the folder called `<languageCode>` inside the [source folder](https://github.com/huggingface/diffusers/tree/main/docs/source).
+* Register your translation in `<languageCode>/_toctree.yml`; please follow the order of the [English version](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml).
+* Once you're finished, open a pull request and tag this issue by including #issue-number in the description, where issue-number is the number of this issue. Please ping @stevhliu for review.
+* 🙋 If you'd like others to help you with the translation, you can also post in the 🤗 [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63).
+
+Thank you so much for your help! 🤗
diff --git a/diffusers/.github/PULL_REQUEST_TEMPLATE.md b/diffusers/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000000000000000000000000000000000000..e4b2b45a4ecde200e65e3a2e16ac0cdb0b87d7c3
--- /dev/null
+++ b/diffusers/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,61 @@
+# What does this PR do?
+
+
+
+
+
+Fixes # (issue)
+
+
+## Before submitting
+- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
+- [ ] Did you read the [contributor guideline](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md)?
+- [ ] Did you read our [philosophy doc](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) (important for complex PRs)?
+- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)? Please add a link to it if that's the case.
+- [ ] Did you make sure to update the documentation with your changes? Here are the
+ [documentation guidelines](https://github.com/huggingface/diffusers/tree/main/docs), and
+ [here are tips on formatting docstrings](https://github.com/huggingface/diffusers/tree/main/docs#writing-source-documentation).
+- [ ] Did you write any new necessary tests?
+
+
+## Who can review?
+
+Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
+members/contributors who may be interested in your PR.
+
+
diff --git a/diffusers/.github/actions/setup-miniconda/action.yml b/diffusers/.github/actions/setup-miniconda/action.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b1f4f194bfe1fd14e03239269e466e7978e3d5c5
--- /dev/null
+++ b/diffusers/.github/actions/setup-miniconda/action.yml
@@ -0,0 +1,146 @@
+name: Set up conda environment for testing
+
+description: Sets up miniconda in your ${RUNNER_TEMP} environment and gives you the ${CONDA_RUN} environment variable so you don't have to worry about polluting non-ephemeral runners anymore
+
+inputs:
+ python-version:
+ description: Python version to install in the conda environment
+ required: false
+ type: string
+ default: "3.9"
+ miniconda-version:
+ description: Miniconda version to install
+ required: false
+ type: string
+ default: "4.12.0"
+ environment-file:
+ description: Environment file to install dependencies from
+ required: false
+ type: string
+ default: ""
+
+runs:
+ using: composite
+ steps:
+ # Use the same trick from https://github.com/marketplace/actions/setup-miniconda
+ # to refresh the cache daily. This is kind of optional though
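+ # The dated output from this step (e.g. "20240501d") is embedded in the cache keys below so they roll over daily.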
+ - name: Get date
+ id: get-date
+ shell: bash
+ run: echo "today=$(/bin/date -u '+%Y%m%d')d" >> $GITHUB_OUTPUT
+ - name: Setup miniconda cache
+ id: miniconda-cache
+ uses: actions/cache@v2
+ with:
+ path: ${{ runner.temp }}/miniconda
+ key: miniconda-${{ runner.os }}-${{ runner.arch }}-${{ inputs.python-version }}-${{ steps.get-date.outputs.today }}
+ - name: Install miniconda (${{ inputs.miniconda-version }})
+ if: steps.miniconda-cache.outputs.cache-hit != 'true'
+ env:
+ MINICONDA_VERSION: ${{ inputs.miniconda-version }}
+ shell: bash -l {0}
+ run: |
+ MINICONDA_INSTALL_PATH="${RUNNER_TEMP}/miniconda"
+ mkdir -p "${MINICONDA_INSTALL_PATH}"
+ case ${RUNNER_OS}-${RUNNER_ARCH} in
+ Linux-X64)
+ MINICONDA_ARCH="Linux-x86_64"
+ ;;
+ macOS-ARM64)
+ MINICONDA_ARCH="MacOSX-arm64"
+ ;;
+ macOS-X64)
+ MINICONDA_ARCH="MacOSX-x86_64"
+ ;;
+ *)
+ echo "::error::Platform ${RUNNER_OS}-${RUNNER_ARCH} currently unsupported using this action"
+ exit 1
+ ;;
+ esac
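+ # Note: the installer filename pins the py39 build of Miniconda; the Python version requested via the input is created later with conda create.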
+ MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-py39_${MINICONDA_VERSION}-${MINICONDA_ARCH}.sh"
+ curl -fsSL "${MINICONDA_URL}" -o "${MINICONDA_INSTALL_PATH}/miniconda.sh"
+ bash "${MINICONDA_INSTALL_PATH}/miniconda.sh" -b -u -p "${MINICONDA_INSTALL_PATH}"
+ rm -rf "${MINICONDA_INSTALL_PATH}/miniconda.sh"
+ - name: Update GitHub path to include miniconda install
+ shell: bash
+ run: |
+ MINICONDA_INSTALL_PATH="${RUNNER_TEMP}/miniconda"
+ echo "${MINICONDA_INSTALL_PATH}/bin" >> $GITHUB_PATH
+ - name: Setup miniconda env cache (with env file)
+ id: miniconda-env-cache-env-file
+ if: ${{ runner.os }} == 'macOS' && ${{ inputs.environment-file }} != ''
+ uses: actions/cache@v2
+ with:
+ path: ${{ runner.temp }}/conda-python-${{ inputs.python-version }}
+ key: miniconda-env-${{ runner.os }}-${{ runner.arch }}-${{ inputs.python-version }}-${{ steps.get-date.outputs.today }}-${{ hashFiles(inputs.environment-file) }}
+ - name: Setup miniconda env cache (without env file)
+ id: miniconda-env-cache
+ if: ${{ runner.os }} == 'macOS' && ${{ inputs.environment-file }} == ''
+ uses: actions/cache@v2
+ with:
+ path: ${{ runner.temp }}/conda-python-${{ inputs.python-version }}
+ key: miniconda-env-${{ runner.os }}-${{ runner.arch }}-${{ inputs.python-version }}-${{ steps.get-date.outputs.today }}
+ - name: Setup conda environment with python (v${{ inputs.python-version }})
+ if: steps.miniconda-env-cache-env-file.outputs.cache-hit != 'true' && steps.miniconda-env-cache.outputs.cache-hit != 'true'
+ shell: bash
+ env:
+ PYTHON_VERSION: ${{ inputs.python-version }}
+ ENV_FILE: ${{ inputs.environment-file }}
+ run: |
+ CONDA_BASE_ENV="${RUNNER_TEMP}/conda-python-${PYTHON_VERSION}"
+ ENV_FILE_FLAG=""
+ if [[ -f "${ENV_FILE}" ]]; then
+ ENV_FILE_FLAG="--file ${ENV_FILE}"
+ elif [[ -n "${ENV_FILE}" ]]; then
+ echo "::warning::Specified env file (${ENV_FILE}) not found, not going to include it"
+ fi
+ conda create \
+ --yes \
+ --prefix "${CONDA_BASE_ENV}" \
+ "python=${PYTHON_VERSION}" \
+ ${ENV_FILE_FLAG} \
+ cmake=3.22 \
+ conda-build=3.21 \
+ ninja=1.10 \
+ pkg-config=0.29 \
+ wheel=0.37
+ - name: Clone the base conda environment and update GitHub env
+ shell: bash
+ env:
+ PYTHON_VERSION: ${{ inputs.python-version }}
+ CONDA_BASE_ENV: ${{ runner.temp }}/conda-python-${{ inputs.python-version }}
+ run: |
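+ # Create one fresh environment per workflow run (keyed by GITHUB_RUN_ID) by cloning the cached base environment.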
+ CONDA_ENV="${RUNNER_TEMP}/conda_environment_${GITHUB_RUN_ID}"
+ conda create \
+ --yes \
+ --prefix "${CONDA_ENV}" \
+ --clone "${CONDA_BASE_ENV}"
+ # TODO: conda-build could not be cloned because it hardcodes the path, so it
+ # could not be cached
+ conda install --yes -p ${CONDA_ENV} conda-build=3.21
+ echo "CONDA_ENV=${CONDA_ENV}" >> "${GITHUB_ENV}"
+ echo "CONDA_RUN=conda run -p ${CONDA_ENV} --no-capture-output" >> "${GITHUB_ENV}"
+ echo "CONDA_BUILD=conda run -p ${CONDA_ENV} conda-build" >> "${GITHUB_ENV}"
+ echo "CONDA_INSTALL=conda install -p ${CONDA_ENV}" >> "${GITHUB_ENV}"
+ - name: Get disk space usage and throw an error for low disk space
+ shell: bash
+ run: |
+ echo "Print the available disk space for manual inspection"
+ df -h
+ # Set the minimum requirement space to 4GB
+ MINIMUM_AVAILABLE_SPACE_IN_GB=4
+ MINIMUM_AVAILABLE_SPACE_IN_KB=$(($MINIMUM_AVAILABLE_SPACE_IN_GB * 1024 * 1024))
+ # Use KB so we compare integers instead of floating-point values like 3.1GB
+ df -k | tr -s ' ' | cut -d' ' -f 4,9 | while read -r LINE;
+ do
+ AVAIL=$(echo $LINE | cut -f1 -d' ')
+ MOUNT=$(echo $LINE | cut -f2 -d' ')
+ if [ "$MOUNT" = "/" ]; then
+ if [ "$AVAIL" -lt "$MINIMUM_AVAILABLE_SPACE_IN_KB" ]; then
+ echo "There is only ${AVAIL}KB free space left in $MOUNT, which is less than the minimum requirement of ${MINIMUM_AVAILABLE_SPACE_IN_KB}KB. Please help create an issue to PyTorch Release Engineering via https://github.com/pytorch/test-infra/issues and provide the link to the workflow run."
+ exit 1;
+ else
+ echo "There is ${AVAIL}KB free space left in $MOUNT, continue"
+ fi
+ fi
+ done
diff --git a/diffusers/.github/workflows/benchmark.yml b/diffusers/.github/workflows/benchmark.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d311c1c73f1105bfd7d942ba4a80c6d431ebf30a
--- /dev/null
+++ b/diffusers/.github/workflows/benchmark.yml
@@ -0,0 +1,67 @@
+name: Benchmarking tests
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: "30 1 1,15 * *" # every 2 weeks on the 1st and the 15th of every month at 1:30 AM
+
+env:
+ DIFFUSERS_IS_CI: yes
+ HF_HUB_ENABLE_HF_TRANSFER: 1
+ HF_HOME: /mnt/cache
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+
+jobs:
+ torch_pipelines_cuda_benchmark_tests:
+ env:
+ SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}
+ name: Torch Core Pipelines CUDA Benchmarking Tests
+ strategy:
+ fail-fast: false
+ max-parallel: 1
+ runs-on:
+ group: aws-g6-4xlarge-plus
+ container:
+ image: diffusers/diffusers-pytorch-compile-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install pandas peft
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Diffusers Benchmarking
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
+ BASE_PATH: benchmark_outputs
+ run: |
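+ # Expose the total GPU memory (in GiB); presumably read by the benchmark scripts invoked below.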
+ export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
+ cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: benchmark_test_reports
+ path: benchmarks/benchmark_outputs
+
+ - name: Report success status
+ if: ${{ success() }}
+ run: |
+ pip install requests && python utils/notify_benchmarking_status.py --status=success
+
+ - name: Report failure status
+ if: ${{ failure() }}
+ run: |
+ pip install requests && python utils/notify_benchmarking_status.py --status=failure
\ No newline at end of file
diff --git a/diffusers/.github/workflows/build_docker_images.yml b/diffusers/.github/workflows/build_docker_images.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9f4776db4315d9418a8a5bd07877f0a08f9ec2dd
--- /dev/null
+++ b/diffusers/.github/workflows/build_docker_images.yml
@@ -0,0 +1,103 @@
+name: Test, build, and push Docker images
+
+on:
+ pull_request: # During PRs, we just check if the changed Dockerfiles can be successfully built
+ branches:
+ - main
+ paths:
+ - "docker/**"
+ workflow_dispatch:
+ schedule:
+ - cron: "0 0 * * *" # every day at midnight
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+env:
+ REGISTRY: diffusers
+ CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
+
+jobs:
+ test-build-docker-images:
+ runs-on:
+ group: aws-general-8-plus
+ if: github.event_name == 'pull_request'
+ steps:
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v1
+
+ - name: Check out code
+ uses: actions/checkout@v3
+
+ - name: Find Changed Dockerfiles
+ id: file_changes
+ uses: jitterbit/get-changed-files@v1
+ with:
+ format: 'space-delimited'
+ token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Build Changed Docker Images
+ run: |
+ CHANGED_FILES="${{ steps.file_changes.outputs.all }}"
+ for FILE in $CHANGED_FILES; do
+ if [[ "$FILE" == docker/*Dockerfile ]]; then
+ DOCKER_PATH="${FILE%/Dockerfile}"
+ DOCKER_TAG=$(basename "$DOCKER_PATH")
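+ # e.g. FILE=docker/diffusers-pytorch-cpu/Dockerfile -> DOCKER_PATH=docker/diffusers-pytorch-cpu, DOCKER_TAG=diffusers-pytorch-cpu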
+ echo "Building Docker image for $DOCKER_TAG"
+ docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
+ fi
+ done
+ if: steps.file_changes.outputs.all != ''
+
+ build-and-push-docker-images:
+ runs-on:
+ group: aws-general-8-plus
+ if: github.event_name != 'pull_request'
+
+ permissions:
+ contents: read
+ packages: write
+
+ strategy:
+ fail-fast: false
+ matrix:
+ image-name:
+ - diffusers-pytorch-cpu
+ - diffusers-pytorch-cuda
+ - diffusers-pytorch-compile-cuda
+ - diffusers-pytorch-xformers-cuda
+ - diffusers-flax-cpu
+ - diffusers-flax-tpu
+ - diffusers-onnxruntime-cpu
+ - diffusers-onnxruntime-cuda
+ - diffusers-doc-builder
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v3
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v1
+ - name: Login to Docker Hub
+ uses: docker/login-action@v2
+ with:
+ username: ${{ env.REGISTRY }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+ - name: Build and push
+ uses: docker/build-push-action@v3
+ with:
+ no-cache: true
+ context: ./docker/${{ matrix.image-name }}
+ push: true
+ tags: ${{ env.REGISTRY }}/${{ matrix.image-name }}:latest
+
+ - name: Post to a Slack channel
+ id: slack
+ uses: huggingface/hf-workflows/.github/actions/post-slack@main
+ with:
+ # Slack channel id, channel name, or user id to post message.
+ # See also: https://api.slack.com/methods/chat.postMessage#channels
+ slack_channel: ${{ env.CI_SLACK_CHANNEL }}
+ title: "🤗 Results of the ${{ matrix.image-name }} Docker Image build"
+ status: ${{ job.status }}
+ slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
diff --git a/diffusers/.github/workflows/build_documentation.yml b/diffusers/.github/workflows/build_documentation.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6d4193e3cccc42e6824e9f0881a6b3e50bfa7173
--- /dev/null
+++ b/diffusers/.github/workflows/build_documentation.yml
@@ -0,0 +1,27 @@
+name: Build documentation
+
+on:
+ push:
+ branches:
+ - main
+ - doc-builder*
+ - v*-release
+ - v*-patch
+ paths:
+ - "src/diffusers/**.py"
+ - "examples/**"
+ - "docs/**"
+
+jobs:
+ build:
+ uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
+ with:
+ commit_sha: ${{ github.sha }}
+ install_libgl1: true
+ package: diffusers
+ notebook_folder: diffusers_doc
+ languages: en ko zh ja pt
+ custom_container: diffusers/diffusers-doc-builder
+ secrets:
+ token: ${{ secrets.HUGGINGFACE_PUSH }}
+ hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
diff --git a/diffusers/.github/workflows/build_pr_documentation.yml b/diffusers/.github/workflows/build_pr_documentation.yml
new file mode 100644
index 0000000000000000000000000000000000000000..52e0757331639c7335132766cfc2afb2d74e0368
--- /dev/null
+++ b/diffusers/.github/workflows/build_pr_documentation.yml
@@ -0,0 +1,23 @@
+name: Build PR Documentation
+
+on:
+ pull_request:
+ paths:
+ - "src/diffusers/**.py"
+ - "examples/**"
+ - "docs/**"
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ build:
+ uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
+ with:
+ commit_sha: ${{ github.event.pull_request.head.sha }}
+ pr_number: ${{ github.event.number }}
+ install_libgl1: true
+ package: diffusers
+ languages: en ko zh ja pt
+ custom_container: diffusers/diffusers-doc-builder
diff --git a/diffusers/.github/workflows/mirror_community_pipeline.yml b/diffusers/.github/workflows/mirror_community_pipeline.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a7a2a809bbebc5b32aec5d3743d746dd0d7c794b
--- /dev/null
+++ b/diffusers/.github/workflows/mirror_community_pipeline.yml
@@ -0,0 +1,102 @@
+name: Mirror Community Pipeline
+
+on:
+ # Push changes on the main branch
+ push:
+ branches:
+ - main
+ paths:
+ - 'examples/community/**.py'
+
+ # And on tag creation (e.g. `v0.28.1`)
+ tags:
+ - '*'
+
+ # Manual trigger with ref input
+ workflow_dispatch:
+ inputs:
+ ref:
+ description: "Either 'main' or a tag ref"
+ required: true
+ default: 'main'
+
+jobs:
+ mirror_community_pipeline:
+ env:
+ SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
+
+ runs-on: ubuntu-latest
+ steps:
+ # Checkout to correct ref
+ # If workflow dispatch
+ # If ref is 'main', set:
+ # CHECKOUT_REF=refs/heads/main
+ # PATH_IN_REPO=main
+ # Else it must be a tag. Set:
+ # CHECKOUT_REF=refs/tags/{tag}
+ # PATH_IN_REPO={tag}
+ # If not workflow dispatch
+ # If ref is 'refs/heads/main' => set 'main'
+ # Else it must be a tag => set {tag}
+ - name: Set checkout_ref and path_in_repo
+ run: |
+ if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
+ if [ -z "${{ github.event.inputs.ref }}" ]; then
+ echo "Error: Missing ref input"
+ exit 1
+ elif [ "${{ github.event.inputs.ref }}" == "main" ]; then
+ echo "CHECKOUT_REF=refs/heads/main" >> $GITHUB_ENV
+ echo "PATH_IN_REPO=main" >> $GITHUB_ENV
+ else
+ echo "CHECKOUT_REF=refs/tags/${{ github.event.inputs.ref }}" >> $GITHUB_ENV
+ echo "PATH_IN_REPO=${{ github.event.inputs.ref }}" >> $GITHUB_ENV
+ fi
+ elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
+ echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
+ echo "PATH_IN_REPO=main" >> $GITHUB_ENV
+ else
+ # e.g. refs/tags/v0.28.1 -> v0.28.1
+ echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
+ echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
+ fi
+ - name: Print env vars
+ run: |
+ echo "CHECKOUT_REF: ${{ env.CHECKOUT_REF }}"
+ echo "PATH_IN_REPO: ${{ env.PATH_IN_REPO }}"
+ - uses: actions/checkout@v3
+ with:
+ ref: ${{ env.CHECKOUT_REF }}
+
+ # Setup + install dependencies
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install --upgrade huggingface_hub
+
+ # Check secret is set
+ - name: whoami
+ run: huggingface-cli whoami
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
+
+ # Push to HF! (under subfolder based on checkout ref)
+ # https://huggingface.co/datasets/diffusers/community-pipelines-mirror
+ - name: Mirror community pipeline to HF
+ run: huggingface-cli upload diffusers/community-pipelines-mirror ./examples/community ${PATH_IN_REPO} --repo-type dataset
+ env:
+ PATH_IN_REPO: ${{ env.PATH_IN_REPO }}
+ HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
+
+ - name: Report success status
+ if: ${{ success() }}
+ run: |
+ pip install requests && python utils/notify_community_pipelines_mirror.py --status=success
+
+ - name: Report failure status
+ if: ${{ failure() }}
+ run: |
+ pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
\ No newline at end of file
diff --git a/diffusers/.github/workflows/nightly_tests.yml b/diffusers/.github/workflows/nightly_tests.yml
new file mode 100644
index 0000000000000000000000000000000000000000..142dbb0f1e8fcddc0d27951904990edd00c40d9c
--- /dev/null
+++ b/diffusers/.github/workflows/nightly_tests.yml
@@ -0,0 +1,352 @@
+name: Nightly and release tests on main/release branch
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: "0 0 * * *" # every day at midnight
+
+env:
+ DIFFUSERS_IS_CI: yes
+ HF_HUB_ENABLE_HF_TRANSFER: 1
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ PYTEST_TIMEOUT: 600
+ RUN_SLOW: yes
+ RUN_NIGHTLY: yes
+ PIPELINE_USAGE_CUTOFF: 5000
+ SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
+
+jobs:
+ setup_torch_cuda_pipeline_matrix:
+ name: Setup Torch Pipelines CUDA Slow Tests Matrix
+ runs-on:
+ group: aws-general-8-plus
+ container:
+ image: diffusers/diffusers-pytorch-cpu
+ outputs:
+ pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: Install dependencies
+ run: |
+ pip install -e .[test]
+ pip install huggingface_hub
+ - name: Fetch Pipeline Matrix
+ id: fetch_pipeline_matrix
+ run: |
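+ # The script is expected to print a JSON list of pipeline module names, consumed by the jobs below via fromJson().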
+ matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
+ echo $matrix
+ echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
+
+ - name: Pipeline Tests Artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: test-pipelines.json
+ path: reports
+
+ run_nightly_tests_for_torch_pipelines:
+ name: Nightly Torch Pipelines CUDA Tests
+ needs: setup_torch_cuda_pipeline_matrix
+ strategy:
+ fail-fast: false
+ max-parallel: 8
+ matrix:
+ module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: NVIDIA-SMI
+ run: nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ python -m uv pip install pytest-reportlog
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Pipeline CUDA Test
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
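+ # ":16:8" restricts the cuBLAS workspace configuration, which is needed for deterministic results on CUDA.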
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
+ --report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
+ tests/pipelines/${{ matrix.module }}
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pipeline_${{ matrix.module }}_test_reports
+ path: reports
+ - name: Generate Report and Notify Channel
+ if: always()
+ run: |
+ pip install slack_sdk tabulate
+ python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+
+ run_nightly_tests_for_other_torch_modules:
+ name: Nightly Torch CUDA Tests
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ defaults:
+ run:
+ shell: bash
+ strategy:
+ fail-fast: false
+ max-parallel: 2
+ matrix:
+ module: [models, schedulers, lora, others, single_file, examples]
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install peft@git+https://github.com/huggingface/peft.git
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ python -m uv pip install pytest-reportlog
+ - name: Environment
+ run: python utils/print_env.py
+
+ - name: Run nightly PyTorch CUDA tests for non-pipeline modules
+ if: ${{ matrix.module != 'examples'}}
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_torch_${{ matrix.module }}_cuda \
+ --report-log=tests_torch_${{ matrix.module }}_cuda.log \
+ tests/${{ matrix.module }}
+
+ - name: Run nightly example tests with Torch
+ if: ${{ matrix.module == 'examples' }}
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v --make-reports=examples_torch_cuda \
+ --report-log=examples_torch_cuda.log \
+ examples/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_torch_${{ matrix.module }}_cuda_stats.txt
+ cat reports/tests_torch_${{ matrix.module }}_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_${{ matrix.module }}_cuda_test_reports
+ path: reports
+
+ - name: Generate Report and Notify Channel
+ if: always()
+ run: |
+ pip install slack_sdk tabulate
+ python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+
+ run_flax_tpu_tests:
+ name: Nightly Flax TPU Tests
+ runs-on: docker-tpu
+ if: github.event_name == 'schedule'
+
+ container:
+ image: diffusers/diffusers-flax-tpu
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ python -m uv pip install pytest-reportlog
+
+ - name: Environment
+ run: python utils/print_env.py
+
+ - name: Run nightly Flax TPU tests
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ run: |
+ python -m pytest -n 0 \
+ -s -v -k "Flax" \
+ --make-reports=tests_flax_tpu \
+ --report-log=tests_flax_tpu.log \
+ tests/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_flax_tpu_stats.txt
+ cat reports/tests_flax_tpu_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: flax_tpu_test_reports
+ path: reports
+
+ - name: Generate Report and Notify Channel
+ if: always()
+ run: |
+ pip install slack_sdk tabulate
+ python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+
+ run_nightly_onnx_tests:
+ name: Nightly ONNXRuntime CUDA tests on Ubuntu
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-onnxruntime-cuda
+ options: --gpus 0 --shm-size "16gb" --ipc host
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: nvidia-smi
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ python -m uv pip install pytest-reportlog
+ - name: Environment
+ run: python utils/print_env.py
+
+ - name: Run Nightly ONNXRuntime CUDA tests
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "Onnx" \
+ --make-reports=tests_onnx_cuda \
+ --report-log=tests_onnx_cuda.log \
+ tests/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_onnx_cuda_stats.txt
+ cat reports/tests_onnx_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: tests_onnx_cuda_reports
+ path: reports
+
+ - name: Generate Report and Notify Channel
+ if: always()
+ run: |
+ pip install slack_sdk tabulate
+ python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+
+# M1 runner currently not well supported
+# TODO: (Dhruv) add these back when we set up better testing for Apple Silicon
+# run_nightly_tests_apple_m1:
+# name: Nightly PyTorch MPS tests on MacOS
+# runs-on: [ self-hosted, apple-m1 ]
+# if: github.event_name == 'schedule'
+#
+# steps:
+# - name: Checkout diffusers
+# uses: actions/checkout@v3
+# with:
+# fetch-depth: 2
+#
+# - name: Clean checkout
+# shell: arch -arch arm64 bash {0}
+# run: |
+# git clean -fxd
+# - name: Setup miniconda
+# uses: ./.github/actions/setup-miniconda
+# with:
+# python-version: 3.9
+#
+# - name: Install dependencies
+# shell: arch -arch arm64 bash {0}
+# run: |
+# ${CONDA_RUN} python -m pip install --upgrade pip uv
+# ${CONDA_RUN} python -m uv pip install -e [quality,test]
+# ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
+# ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
+# ${CONDA_RUN} python -m uv pip install pytest-reportlog
+# - name: Environment
+# shell: arch -arch arm64 bash {0}
+# run: |
+# ${CONDA_RUN} python utils/print_env.py
+# - name: Run nightly PyTorch tests on M1 (MPS)
+# shell: arch -arch arm64 bash {0}
+# env:
+# HF_HOME: /System/Volumes/Data/mnt/cache
+# HF_TOKEN: ${{ secrets.HF_TOKEN }}
+# run: |
+# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
+# --report-log=tests_torch_mps.log \
+# tests/
+# - name: Failure short reports
+# if: ${{ failure() }}
+# run: cat reports/tests_torch_mps_failures_short.txt
+#
+# - name: Test suite reports artifacts
+# if: ${{ always() }}
+# uses: actions/upload-artifact@v4
+# with:
+# name: torch_mps_test_reports
+# path: reports
+#
+# - name: Generate Report and Notify Channel
+# if: always()
+# run: |
+# pip install slack_sdk tabulate
+# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
\ No newline at end of file
diff --git a/diffusers/.github/workflows/notify_slack_about_release.yml b/diffusers/.github/workflows/notify_slack_about_release.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f33e917f095cf64f0b8b640d65e66fb54849517a
--- /dev/null
+++ b/diffusers/.github/workflows/notify_slack_about_release.yml
@@ -0,0 +1,23 @@
+name: Notify Slack about a release
+
+on:
+ workflow_dispatch:
+ release:
+ types: [published]
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Setup Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.8'
+
+ - name: Notify Slack about the release
+ env:
+ SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
+ run: pip install requests && python utils/notify_slack_about_release.py
diff --git a/diffusers/.github/workflows/pr_dependency_test.yml b/diffusers/.github/workflows/pr_dependency_test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c4bd86112fb88556315d25ae9824287e1b5e0ee4
--- /dev/null
+++ b/diffusers/.github/workflows/pr_dependency_test.yml
@@ -0,0 +1,35 @@
+name: Run dependency tests
+
+on:
+ pull_request:
+ branches:
+ - main
+ paths:
+ - "src/diffusers/**.py"
+ push:
+ branches:
+ - main
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ check_dependencies:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pip install --upgrade pip uv
+ python -m uv pip install -e .
+ python -m uv pip install pytest
+ - name: Check for soft dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ pytest tests/others/test_dependencies.py
diff --git a/diffusers/.github/workflows/pr_flax_dependency_test.yml b/diffusers/.github/workflows/pr_flax_dependency_test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bbad729299177146ab68cd417ba603f3016104dd
--- /dev/null
+++ b/diffusers/.github/workflows/pr_flax_dependency_test.yml
@@ -0,0 +1,38 @@
+name: Run Flax dependency tests
+
+on:
+ pull_request:
+ branches:
+ - main
+ paths:
+ - "src/diffusers/**.py"
+ push:
+ branches:
+ - main
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ check_flax_dependencies:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pip install --upgrade pip uv
+ python -m uv pip install -e .
+ python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2"
+ python -m uv pip install "flax>=0.4.1"
+ python -m uv pip install "jaxlib>=0.1.65"
+ python -m uv pip install pytest
+ - name: Check for soft dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ pytest tests/others/test_dependencies.py
diff --git a/diffusers/.github/workflows/pr_test_fetcher.yml b/diffusers/.github/workflows/pr_test_fetcher.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b032bb842786a73761b9ef73dfc1cc87a4ab0e26
--- /dev/null
+++ b/diffusers/.github/workflows/pr_test_fetcher.yml
@@ -0,0 +1,177 @@
+name: Fast tests for PRs - Test Fetcher
+
+on: workflow_dispatch
+
+env:
+ DIFFUSERS_IS_CI: yes
+ OMP_NUM_THREADS: 4
+ MKL_NUM_THREADS: 4
+ PYTEST_TIMEOUT: 60
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ setup_pr_tests:
+ name: Setup PR Tests
+ runs-on:
+ group: aws-general-8-plus
+ container:
+ image: diffusers/diffusers-pytorch-cpu
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+ defaults:
+ run:
+ shell: bash
+ outputs:
+ matrix: ${{ steps.set_matrix.outputs.matrix }}
+ test_map: ${{ steps.set_matrix.outputs.test_map }}
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ echo $(git --version)
+ - name: Fetch Tests
+ run: |
+ python utils/tests_fetcher.py | tee test_preparation.txt
+ - name: Report fetched tests
+ uses: actions/upload-artifact@v3
+ with:
+ name: test_fetched
+ path: test_preparation.txt
+ - id: set_matrix
+ name: Create Test Matrix
+ # The `keys` output is used as the GitHub Actions matrix for jobs, e.g. `models`, `pipelines`, etc.
+ # The `test_map` output maps each key to the identified test files it should run.
+ # If there are no tests to run (so no `test_map.json` file), create a dummy map (an empty matrix would fail).
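+ # Example outputs: matrix='["models", "pipelines"]', test_map='{"models": [...], "pipelines": [...]}'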
+ run: |
+ if [ -f test_map.json ]; then
+ keys=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); d = list(test_map.keys()); print(json.dumps(d))')
+ test_map=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); print(json.dumps(test_map))')
+ else
+ keys=$(python3 -c 'keys = ["dummy"]; print(keys)')
+ test_map=$(python3 -c 'test_map = {"dummy": []}; print(test_map)')
+ fi
+ echo $keys
+ echo $test_map
+ echo "matrix=$keys" >> $GITHUB_OUTPUT
+ echo "test_map=$test_map" >> $GITHUB_OUTPUT
+
+ run_pr_tests:
+ name: Run PR Tests
+ needs: setup_pr_tests
+ if: contains(fromJson(needs.setup_pr_tests.outputs.matrix), 'dummy') != true
+ strategy:
+ fail-fast: false
+ max-parallel: 2
+ matrix:
+ modules: ${{ fromJson(needs.setup_pr_tests.outputs.matrix) }}
+ runs-on:
+ group: aws-general-8-plus
+ container:
+ image: diffusers/diffusers-pytorch-cpu
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pip install -e [quality,test]
+ python -m pip install accelerate
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run all selected tests on CPU
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ continue-on-error: true
+ run: |
+ cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
+ cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v3
+ with:
+ name: ${{ matrix.modules }}_test_reports
+ path: reports
+
+ run_staging_tests:
+ strategy:
+ fail-fast: false
+ matrix:
+ config:
+ - name: Hub tests for models, schedulers, and pipelines
+ framework: hub_tests_pytorch
+ runner: aws-general-8-plus
+ image: diffusers/diffusers-pytorch-cpu
+ report: torch_hub
+
+ name: ${{ matrix.config.name }}
+ runs-on:
+ group: ${{ matrix.config.runner }}
+ container:
+ image: ${{ matrix.config.image }}
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+ defaults:
+ run:
+ shell: bash
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pip install -e [quality,test]
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run Hub tests for models, schedulers, and pipelines on a staging env
+ if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ HUGGINGFACE_CO_STAGING=true python -m pytest \
+ -m "is_staging_test" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pr_${{ matrix.config.report }}_test_reports
+ path: reports
diff --git a/diffusers/.github/workflows/pr_test_peft_backend.yml b/diffusers/.github/workflows/pr_test_peft_backend.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0433067e54bad832d798ed1434d3764154ce6753
--- /dev/null
+++ b/diffusers/.github/workflows/pr_test_peft_backend.yml
@@ -0,0 +1,132 @@
+name: Fast tests for PRs - PEFT backend
+
+on:
+ pull_request:
+ branches:
+ - main
+ paths:
+ - "src/diffusers/**.py"
+ - "tests/**.py"
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+env:
+ DIFFUSERS_IS_CI: yes
+ OMP_NUM_THREADS: 4
+ MKL_NUM_THREADS: 4
+ PYTEST_TIMEOUT: 60
+
+jobs:
+ check_code_quality:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install .[quality]
+ - name: Check quality
+ run: make quality
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
+
+ check_repository_consistency:
+ needs: check_code_quality
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install .[quality]
+ - name: Check repo consistency
+ run: |
+ python utils/check_copies.py
+ python utils/check_dummies.py
+ make deps_table_check_updated
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+
+ run_fast_tests:
+ needs: [check_code_quality, check_repository_consistency]
+ strategy:
+ fail-fast: false
+ matrix:
+ lib-versions: ["main", "latest"]
+
+
+ name: LoRA - ${{ matrix.lib-versions }}
+
+ runs-on:
+ group: aws-general-8-plus
+
+ container:
+ image: diffusers/diffusers-pytorch-cpu
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+ defaults:
+ run:
+ shell: bash
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
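+ # "main" installs peft/transformers/accelerate from their GitHub main branches; "latest" uses the newest PyPI releases.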
+ if [ "${{ matrix.lib-versions }}" == "main" ]; then
+ python -m pip install -U peft@git+https://github.com/huggingface/peft.git
+ python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ else
+ python -m uv pip install -U peft transformers accelerate
+ fi
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run fast PyTorch LoRA CPU tests with PEFT backend
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -s -v \
+ --make-reports=tests_${{ matrix.lib-versions }} \
+ tests/lora/
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -s -v \
+ --make-reports=tests_models_lora_${{ matrix.lib-versions }} \
+ tests/models/ -k "lora"
+
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_${{ matrix.lib-versions }}_failures_short.txt
+ cat reports/tests_models_lora_${{ matrix.lib-versions }}_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pr_${{ matrix.lib-versions }}_test_reports
+ path: reports
diff --git a/diffusers/.github/workflows/pr_tests.yml b/diffusers/.github/workflows/pr_tests.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6bb67976b170cdaab0323508db37ee939e57c4af
--- /dev/null
+++ b/diffusers/.github/workflows/pr_tests.yml
@@ -0,0 +1,236 @@
+name: Fast tests for PRs
+
+on:
+ pull_request:
+ branches:
+ - main
+ paths:
+ - "src/diffusers/**.py"
+ - "benchmarks/**.py"
+ - "examples/**.py"
+ - "scripts/**.py"
+ - "tests/**.py"
+ - ".github/**.yml"
+ - "utils/**.py"
+ push:
+ branches:
+ - ci-*
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+env:
+ DIFFUSERS_IS_CI: yes
+ HF_HUB_ENABLE_HF_TRANSFER: 1
+ OMP_NUM_THREADS: 4
+ MKL_NUM_THREADS: 4
+ PYTEST_TIMEOUT: 60
+
+jobs:
+ check_code_quality:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install .[quality]
+ - name: Check quality
+ run: make quality
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
+
+ check_repository_consistency:
+ needs: check_code_quality
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install .[quality]
+ - name: Check repo consistency
+ run: |
+ python utils/check_copies.py
+ python utils/check_dummies.py
+ make deps_table_check_updated
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+
+ run_fast_tests:
+ needs: [check_code_quality, check_repository_consistency]
+ strategy:
+ fail-fast: false
+ matrix:
+ config:
+ - name: Fast PyTorch Pipeline CPU tests
+ framework: pytorch_pipelines
+ runner: aws-highmemory-32-plus
+ image: diffusers/diffusers-pytorch-cpu
+ report: torch_cpu_pipelines
+ - name: Fast PyTorch Models & Schedulers CPU tests
+ framework: pytorch_models
+ runner: aws-general-8-plus
+ image: diffusers/diffusers-pytorch-cpu
+ report: torch_cpu_models_schedulers
+ - name: Fast Flax CPU tests
+ framework: flax
+ runner: aws-general-8-plus
+ image: diffusers/diffusers-flax-cpu
+ report: flax_cpu
+ - name: PyTorch Example CPU tests
+ framework: pytorch_examples
+ runner: aws-general-8-plus
+ image: diffusers/diffusers-pytorch-cpu
+ report: torch_example_cpu
+
+ name: ${{ matrix.config.name }}
+
+ runs-on:
+ group: ${{ matrix.config.runner }}
+
+ container:
+ image: ${{ matrix.config.image }}
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+ defaults:
+ run:
+ shell: bash
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install accelerate
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run fast PyTorch Pipeline CPU tests
+ if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests/pipelines
+
+ - name: Run fast PyTorch Model Scheduler CPU tests
+ if: ${{ matrix.config.framework == 'pytorch_models' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx and not Dependency" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests/models tests/schedulers tests/others
+
+ - name: Run fast Flax CPU tests
+ if: ${{ matrix.config.framework == 'flax' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "Flax" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests
+
+ - name: Run example PyTorch CPU tests
+ if: ${{ matrix.config.framework == 'pytorch_examples' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install peft timm
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ --make-reports=tests_${{ matrix.config.report }} \
+ examples
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
+ path: reports
+
+ run_staging_tests:
+ needs: [check_code_quality, check_repository_consistency]
+ strategy:
+ fail-fast: false
+ matrix:
+ config:
+ - name: Hub tests for models, schedulers, and pipelines
+ framework: hub_tests_pytorch
+ runner:
+ group: aws-general-8-plus
+ image: diffusers/diffusers-pytorch-cpu
+ report: torch_hub
+
+ name: ${{ matrix.config.name }}
+
+ runs-on: ${{ matrix.config.runner }}
+
+ container:
+ image: ${{ matrix.config.image }}
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+ defaults:
+ run:
+ shell: bash
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run Hub tests for models, schedulers, and pipelines on a staging env
+ if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
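+ # HUGGINGFACE_CO_STAGING=true makes huggingface_hub target the staging Hub instead of production.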
+ HUGGINGFACE_CO_STAGING=true python -m pytest \
+ -m "is_staging_test" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pr_${{ matrix.config.report }}_test_reports
+ path: reports
diff --git a/diffusers/.github/workflows/pr_torch_dependency_test.yml b/diffusers/.github/workflows/pr_torch_dependency_test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..16a7724fe744beb2d0e173fa85ad7ab1075f49ab
--- /dev/null
+++ b/diffusers/.github/workflows/pr_torch_dependency_test.yml
@@ -0,0 +1,36 @@
+name: Run Torch dependency tests
+
+on:
+ pull_request:
+ branches:
+ - main
+ paths:
+ - "src/diffusers/**.py"
+ push:
+ branches:
+ - main
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ check_torch_dependencies:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pip install --upgrade pip uv
+ python -m uv pip install -e .
+ python -m uv pip install torch torchvision torchaudio
+ python -m uv pip install pytest
+ - name: Check for soft dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ pytest tests/others/test_dependencies.py
diff --git a/diffusers/.github/workflows/push_tests.yml b/diffusers/.github/workflows/push_tests.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f07e6cda0d59b8d73b4a0e28f543a283ab346f7f
--- /dev/null
+++ b/diffusers/.github/workflows/push_tests.yml
@@ -0,0 +1,391 @@
+name: Fast GPU Tests on main
+
+on:
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ paths:
+ - "src/diffusers/**.py"
+ - "examples/**.py"
+ - "tests/**.py"
+
+env:
+ DIFFUSERS_IS_CI: yes
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ HF_HUB_ENABLE_HF_TRANSFER: 1
+ PYTEST_TIMEOUT: 600
+ PIPELINE_USAGE_CUTOFF: 50000
+
+jobs:
+ setup_torch_cuda_pipeline_matrix:
+ name: Setup Torch Pipelines CUDA Slow Tests Matrix
+ runs-on:
+ group: aws-general-8-plus
+ container:
+ image: diffusers/diffusers-pytorch-cpu
+ outputs:
+ pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Fetch Pipeline Matrix
+ id: fetch_pipeline_matrix
+ run: |
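+ # The script presumably filters the pipeline list using PIPELINE_USAGE_CUTOFF from the env block above and prints it as JSON.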
+ matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
+ echo $matrix
+ echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
+ - name: Pipeline Tests Artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: test-pipelines.json
+ path: reports
+
+ torch_pipelines_cuda_tests:
+ name: Torch Pipelines CUDA Tests
+ needs: setup_torch_cuda_pipeline_matrix
+ strategy:
+ fail-fast: false
+ max-parallel: 8
+ matrix:
+ module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Slow PyTorch CUDA checkpoint tests on Ubuntu
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
+ tests/pipelines/${{ matrix.module }}
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pipeline_${{ matrix.module }}_test_reports
+ path: reports
+
+ torch_cuda_tests:
+ name: Torch CUDA Tests
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ defaults:
+ run:
+ shell: bash
+ strategy:
+ fail-fast: false
+ max-parallel: 2
+ matrix:
+ module: [models, schedulers, lora, others, single_file]
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install peft@git+https://github.com/huggingface/peft.git
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run PyTorch CUDA tests
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_torch_cuda_${{ matrix.module }} \
+ tests/${{ matrix.module }}
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_torch_cuda_${{ matrix.module }}_stats.txt
+ cat reports/tests_torch_cuda_${{ matrix.module }}_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_cuda_test_reports_${{ matrix.module }}
+ path: reports
+
+ flax_tpu_tests:
+ name: Flax TPU Tests
+ runs-on: docker-tpu
+ container:
+ image: diffusers/diffusers-flax-tpu
+ options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run slow Flax TPU tests
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ python -m pytest -n 0 \
+ -s -v -k "Flax" \
+ --make-reports=tests_flax_tpu \
+ tests/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_flax_tpu_stats.txt
+ cat reports/tests_flax_tpu_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: flax_tpu_test_reports
+ path: reports
+
+ onnx_cuda_tests:
+ name: ONNX CUDA Tests
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-onnxruntime-cuda
+ options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run slow ONNXRuntime CUDA tests
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "Onnx" \
+ --make-reports=tests_onnx_cuda \
+ tests/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_onnx_cuda_stats.txt
+ cat reports/tests_onnx_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: onnx_cuda_test_reports
+ path: reports
+
+ run_torch_compile_tests:
+ name: PyTorch Compile CUDA tests
+
+ runs-on:
+ group: aws-g4dn-2xlarge
+
+ container:
+ image: diffusers/diffusers-pytorch-compile-cuda
+ options: --gpus 0 --shm-size "16gb" --ipc host
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test,training]
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Run example tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ RUN_COMPILE: yes
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_torch_compile_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_compile_test_reports
+ path: reports
+
+ run_xformers_tests:
+ name: PyTorch xformers CUDA tests
+
+ runs-on:
+ group: aws-g4dn-2xlarge
+
+ container:
+ image: diffusers/diffusers-pytorch-xformers-cuda
+ options: --gpus 0 --shm-size "16gb" --ipc host
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test,training]
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Run example tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_torch_xformers_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_xformers_test_reports
+ path: reports
+
+ run_examples_tests:
+ name: Examples PyTorch CUDA tests on Ubuntu
+
+ runs-on:
+ group: aws-g4dn-2xlarge
+
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --gpus 0 --shm-size "16gb" --ipc host
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test,training]
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run example tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install timm
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/examples_torch_cuda_stats.txt
+ cat reports/examples_torch_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: examples_test_reports
+ path: reports
diff --git a/diffusers/.github/workflows/push_tests_fast.yml b/diffusers/.github/workflows/push_tests_fast.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e8a73446de737999a05b7d3e829e5a4832425bce
--- /dev/null
+++ b/diffusers/.github/workflows/push_tests_fast.yml
@@ -0,0 +1,126 @@
+name: Fast tests on main
+
+on:
+ push:
+ branches:
+ - main
+ paths:
+ - "src/diffusers/**.py"
+ - "examples/**.py"
+ - "tests/**.py"
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+env:
+ DIFFUSERS_IS_CI: yes
+ HF_HOME: /mnt/cache
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ HF_HUB_ENABLE_HF_TRANSFER: 1
+ PYTEST_TIMEOUT: 600
+ RUN_SLOW: no
+
+jobs:
+ run_fast_tests:
+ strategy:
+ fail-fast: false
+ matrix:
+ config:
+ - name: Fast PyTorch CPU tests on Ubuntu
+ framework: pytorch
+ runner: aws-general-8-plus
+ image: diffusers/diffusers-pytorch-cpu
+ report: torch_cpu
+ - name: Fast Flax CPU tests on Ubuntu
+ framework: flax
+ runner: aws-general-8-plus
+ image: diffusers/diffusers-flax-cpu
+ report: flax_cpu
+ - name: Fast ONNXRuntime CPU tests on Ubuntu
+ framework: onnxruntime
+ runner: aws-general-8-plus
+ image: diffusers/diffusers-onnxruntime-cpu
+ report: onnx_cpu
+ - name: PyTorch Example CPU tests on Ubuntu
+ framework: pytorch_examples
+ runner: aws-general-8-plus
+ image: diffusers/diffusers-pytorch-cpu
+ report: torch_example_cpu
+
+ name: ${{ matrix.config.name }}
+
+ runs-on:
+ group: ${{ matrix.config.runner }}
+
+ container:
+ image: ${{ matrix.config.image }}
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+ defaults:
+ run:
+ shell: bash
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run fast PyTorch CPU tests
+ if: ${{ matrix.config.framework == 'pytorch' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests/
+
+ - name: Run fast Flax TPU tests
+ if: ${{ matrix.config.framework == 'flax' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "Flax" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests/
+
+ - name: Run fast ONNXRuntime CPU tests
+ if: ${{ matrix.config.framework == 'onnxruntime' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "Onnx" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests/
+
+ - name: Run example PyTorch CPU tests
+ if: ${{ matrix.config.framework == 'pytorch_examples' }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install peft timm
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ --make-reports=tests_${{ matrix.config.report }} \
+ examples
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pr_${{ matrix.config.report }}_test_reports
+ path: reports
diff --git a/diffusers/.github/workflows/push_tests_mps.yml b/diffusers/.github/workflows/push_tests_mps.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8d521074a08fdf05c457479daa06f5f99bad9faa
--- /dev/null
+++ b/diffusers/.github/workflows/push_tests_mps.yml
@@ -0,0 +1,76 @@
+name: Fast mps tests on main
+
+on:
+ push:
+ branches:
+ - main
+ paths:
+ - "src/diffusers/**.py"
+ - "tests/**.py"
+
+env:
+ DIFFUSERS_IS_CI: yes
+ HF_HOME: /mnt/cache
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ HF_HUB_ENABLE_HF_TRANSFER: 1
+ PYTEST_TIMEOUT: 600
+ RUN_SLOW: no
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ run_fast_tests_apple_m1:
+ name: Fast PyTorch MPS tests on MacOS
+ runs-on: macos-13-xlarge
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Clean checkout
+ shell: arch -arch arm64 bash {0}
+ run: |
+ git clean -fxd
+
+ - name: Setup miniconda
+ uses: ./.github/actions/setup-miniconda
+ with:
+ python-version: 3.9
+
+ - name: Install dependencies
+ shell: arch -arch arm64 bash {0}
+ run: |
+ ${CONDA_RUN} python -m pip install --upgrade pip uv
+ ${CONDA_RUN} python -m uv pip install -e [quality,test]
+ ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
+ ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
+ ${CONDA_RUN} python -m uv pip install transformers --upgrade
+
+ - name: Environment
+ shell: arch -arch arm64 bash {0}
+ run: |
+ ${CONDA_RUN} python utils/print_env.py
+
+ - name: Run fast PyTorch tests on M1 (MPS)
+ shell: arch -arch arm64 bash {0}
+ env:
+ HF_HOME: /System/Volumes/Data/mnt/cache
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_torch_mps_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pr_torch_mps_test_reports
+ path: reports
diff --git a/diffusers/.github/workflows/pypi_publish.yaml b/diffusers/.github/workflows/pypi_publish.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6b09f60a352324da6ee0ed72aeade2c4ff3d9aeb
--- /dev/null
+++ b/diffusers/.github/workflows/pypi_publish.yaml
@@ -0,0 +1,81 @@
+# Adapted from https://blog.deepjyoti30.dev/pypi-release-github-action
+
+name: PyPI release
+
+on:
+ workflow_dispatch:
+ push:
+ tags:
+ - "*"
+
+jobs:
+ find-and-checkout-latest-branch:
+ runs-on: ubuntu-latest
+ outputs:
+ latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
+ steps:
+ - name: Checkout Repo
+ uses: actions/checkout@v3
+
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.8'
+
+ - name: Fetch latest branch
+ id: fetch_latest_branch
+ run: |
+ pip install -U requests packaging
+ LATEST_BRANCH=$(python utils/fetch_latest_release_branch.py)
+ echo "Latest branch: $LATEST_BRANCH"
+ echo "latest_branch=$LATEST_BRANCH" >> $GITHUB_ENV
+
+ - name: Set latest branch output
+ id: set_latest_branch
+ run: echo "::set-output name=latest_branch::${{ env.latest_branch }}"
+
+ release:
+ needs: find-and-checkout-latest-branch
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout Repo
+ uses: actions/checkout@v3
+ with:
+ ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
+
+ - name: Setup Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -U setuptools wheel twine
+ pip install -U torch --index-url https://download.pytorch.org/whl/cpu
+ pip install -U transformers
+
+ - name: Build the dist files
+ run: python setup.py bdist_wheel && python setup.py sdist
+
+ - name: Publish to the test PyPI
+ env:
+ TWINE_USERNAME: ${{ secrets.TEST_PYPI_USERNAME }}
+ TWINE_PASSWORD: ${{ secrets.TEST_PYPI_PASSWORD }}
+ run: twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
+
+ - name: Test installing diffusers and importing
+ run: |
+ pip install diffusers && pip uninstall diffusers -y
+ pip install -i https://testpypi.python.org/pypi diffusers
+ python -c "from diffusers import __version__; print(__version__)"
+ python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
+ python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
+ python -c "from diffusers import *"
+
+ - name: Publish to PyPI
+ env:
+ TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+ run: twine upload dist/* -r pypi
diff --git a/diffusers/.github/workflows/release_tests_fast.yml b/diffusers/.github/workflows/release_tests_fast.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a8a6f2699dca8a2ae02dd5c8dc9546efdc1e0f5c
--- /dev/null
+++ b/diffusers/.github/workflows/release_tests_fast.yml
@@ -0,0 +1,389 @@
+# Duplicate workflow to push_tests.yml that is meant to run on release/patch branches as a final check
+# Creating a duplicate workflow here is simpler than adding complex path/branch parsing logic to push_tests.yml
+# Needs to be updated if push_tests.yml updated
+name: (Release) Fast GPU Tests on main
+
+on:
+ push:
+ branches:
+ - "v*.*.*-release"
+ - "v*.*.*-patch"
+
+env:
+ DIFFUSERS_IS_CI: yes
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ PYTEST_TIMEOUT: 600
+ PIPELINE_USAGE_CUTOFF: 50000
+
+jobs:
+ setup_torch_cuda_pipeline_matrix:
+ name: Setup Torch Pipelines CUDA Slow Tests Matrix
+ runs-on:
+ group: aws-general-8-plus
+ container:
+ image: diffusers/diffusers-pytorch-cpu
+ outputs:
+ pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Fetch Pipeline Matrix
+ id: fetch_pipeline_matrix
+ run: |
+ matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
+ echo $matrix
+ echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
+ - name: Pipeline Tests Artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: test-pipelines.json
+ path: reports
+
+ torch_pipelines_cuda_tests:
+ name: Torch Pipelines CUDA Tests
+ needs: setup_torch_cuda_pipeline_matrix
+ strategy:
+ fail-fast: false
+ max-parallel: 8
+ matrix:
+ module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Slow PyTorch CUDA checkpoint tests on Ubuntu
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
+ tests/pipelines/${{ matrix.module }}
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pipeline_${{ matrix.module }}_test_reports
+ path: reports
+
+ torch_cuda_tests:
+ name: Torch CUDA Tests
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ defaults:
+ run:
+ shell: bash
+ strategy:
+ fail-fast: false
+ max-parallel: 2
+ matrix:
+ module: [models, schedulers, lora, others, single_file]
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install peft@git+https://github.com/huggingface/peft.git
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run PyTorch CUDA tests
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_torch_${{ matrix.module }}_cuda \
+ tests/${{ matrix.module }}
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_torch_${{ matrix.module }}_cuda_stats.txt
+ cat reports/tests_torch_${{ matrix.module }}_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_cuda_${{ matrix.module }}_test_reports
+ path: reports
+
+ flax_tpu_tests:
+ name: Flax TPU Tests
+ runs-on: docker-tpu
+ container:
+ image: diffusers/diffusers-flax-tpu
+ options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run slow Flax TPU tests
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ python -m pytest -n 0 \
+ -s -v -k "Flax" \
+ --make-reports=tests_flax_tpu \
+ tests/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_flax_tpu_stats.txt
+ cat reports/tests_flax_tpu_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: flax_tpu_test_reports
+ path: reports
+
+ onnx_cuda_tests:
+ name: ONNX CUDA Tests
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-onnxruntime-cuda
+ options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run slow ONNXRuntime CUDA tests
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "Onnx" \
+ --make-reports=tests_onnx_cuda \
+ tests/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_onnx_cuda_stats.txt
+ cat reports/tests_onnx_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: onnx_cuda_test_reports
+ path: reports
+
+ run_torch_compile_tests:
+ name: PyTorch Compile CUDA tests
+
+ runs-on:
+ group: aws-g4dn-2xlarge
+
+ container:
+ image: diffusers/diffusers-pytorch-compile-cuda
+ options: --gpus 0 --shm-size "16gb" --ipc host
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test,training]
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Run example tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ RUN_COMPILE: yes
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_torch_compile_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_compile_test_reports
+ path: reports
+
+ run_xformers_tests:
+ name: PyTorch xformers CUDA tests
+
+ runs-on:
+ group: aws-g4dn-2xlarge
+
+ container:
+ image: diffusers/diffusers-pytorch-xformers-cuda
+ options: --gpus 0 --shm-size "16gb" --ipc host
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test,training]
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Run example tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_torch_xformers_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_xformers_test_reports
+ path: reports
+
+ run_examples_tests:
+ name: Examples PyTorch CUDA tests on Ubuntu
+
+ runs-on:
+ group: aws-g4dn-2xlarge
+
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --gpus 0 --shm-size "16gb" --ipc host
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test,training]
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run example tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install timm
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/examples_torch_cuda_stats.txt
+ cat reports/examples_torch_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: examples_test_reports
+ path: reports
diff --git a/diffusers/.github/workflows/run_tests_from_a_pr.yml b/diffusers/.github/workflows/run_tests_from_a_pr.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1e736e5430891ec5b1b7843ba2deadddadb9dca1
--- /dev/null
+++ b/diffusers/.github/workflows/run_tests_from_a_pr.yml
@@ -0,0 +1,74 @@
+name: Check running SLOW tests from a PR (only GPU)
+
+on:
+ workflow_dispatch:
+ inputs:
+ docker_image:
+ default: 'diffusers/diffusers-pytorch-cuda'
+ description: 'Name of the Docker image'
+ required: true
+ branch:
+ description: 'PR Branch to test on'
+ required: true
+ test:
+ description: 'Tests to run (e.g.: `tests/models`).'
+ required: true
+
+env:
+ DIFFUSERS_IS_CI: yes
+ IS_GITHUB_CI: "1"
+ HF_HOME: /mnt/cache
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ PYTEST_TIMEOUT: 600
+ RUN_SLOW: yes
+
+jobs:
+ run_tests:
+ name: "Run a test on our runner from a PR"
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: ${{ github.event.inputs.docker_image }}
+ options: --gpus 0 --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
+
+ steps:
+ - name: Validate test files input
+ id: validate_test_files
+ env:
+ PY_TEST: ${{ github.event.inputs.test }}
+ run: |
+ if [[ ! "$PY_TEST" =~ ^tests/ ]]; then
+ echo "Error: The input string must start with 'tests/'."
+ exit 1
+ fi
+
+ if [[ ! "$PY_TEST" =~ ^tests/(models|pipelines) ]]; then
+ echo "Error: The input string must contain either 'models' or 'pipelines' after 'tests/'."
+ exit 1
+ fi
+
+ if [[ "$PY_TEST" == *";"* ]]; then
+ echo "Error: The input string must not contain ';'."
+ exit 1
+ fi
+ echo "$PY_TEST"
+
+ - name: Checkout PR branch
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ github.event.inputs.branch }}
+ repository: ${{ github.event.pull_request.head.repo.full_name }}
+
+
+ - name: Install pytest
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install peft
+
+ - name: Run tests
+ env:
+ PY_TEST: ${{ github.event.inputs.test }}
+ run: |
+ pytest "$PY_TEST"
diff --git a/diffusers/.github/workflows/ssh-pr-runner.yml b/diffusers/.github/workflows/ssh-pr-runner.yml
new file mode 100644
index 0000000000000000000000000000000000000000..49fa9c0ad24d65634a891496ac93bc93e9e24378
--- /dev/null
+++ b/diffusers/.github/workflows/ssh-pr-runner.yml
@@ -0,0 +1,40 @@
+name: SSH into PR runners
+
+on:
+ workflow_dispatch:
+ inputs:
+ docker_image:
+ description: 'Name of the Docker image'
+ required: true
+
+env:
+ IS_GITHUB_CI: "1"
+ HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
+ HF_HOME: /mnt/cache
+ DIFFUSERS_IS_CI: yes
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ RUN_SLOW: yes
+
+jobs:
+ ssh_runner:
+ name: "SSH"
+ runs-on:
+ group: aws-highmemory-32-plus
+ container:
+ image: ${{ github.event.inputs.docker_image }}
+ options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --privileged
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Tailscale # In order to be able to SSH when a test fails
+ uses: huggingface/tailscale-action@main
+ with:
+ authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
+ slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
+ slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
+ waitForSSH: true
diff --git a/diffusers/.github/workflows/ssh-runner.yml b/diffusers/.github/workflows/ssh-runner.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0d4fe1578ba69d7668fc59546aa833267b569483
--- /dev/null
+++ b/diffusers/.github/workflows/ssh-runner.yml
@@ -0,0 +1,51 @@
+name: SSH into GPU runners
+
+on:
+ workflow_dispatch:
+ inputs:
+ runner_type:
+ description: 'Type of runner to test (aws-g6-4xlarge-plus: a10 or aws-g4dn-2xlarge: t4)'
+ type: choice
+ required: true
+ options:
+ - aws-g6-4xlarge-plus
+ - aws-g4dn-2xlarge
+ docker_image:
+ description: 'Name of the Docker image'
+ required: true
+
+env:
+ IS_GITHUB_CI: "1"
+ HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
+ HF_HOME: /mnt/cache
+ DIFFUSERS_IS_CI: yes
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ RUN_SLOW: yes
+
+jobs:
+ ssh_runner:
+ name: "SSH"
+ runs-on:
+ group: "${{ github.event.inputs.runner_type }}"
+ container:
+ image: ${{ github.event.inputs.docker_image }}
+ options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0 --privileged
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+
+ - name: Tailscale # In order to be able to SSH when a test fails
+ uses: huggingface/tailscale-action@main
+ with:
+ authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
+ slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
+ slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
+ waitForSSH: true
diff --git a/diffusers/.github/workflows/stale.yml b/diffusers/.github/workflows/stale.yml
new file mode 100644
index 0000000000000000000000000000000000000000..443f65404daf79dd9819daac44baa58e064b8c23
--- /dev/null
+++ b/diffusers/.github/workflows/stale.yml
@@ -0,0 +1,30 @@
+name: Stale Bot
+
+on:
+ schedule:
+ - cron: "0 15 * * *"
+
+jobs:
+ close_stale_issues:
+ name: Close Stale Issues
+ if: github.repository == 'huggingface/diffusers'
+ runs-on: ubuntu-latest
+ permissions:
+ issues: write
+ pull-requests: write
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Setup Python
+ uses: actions/setup-python@v1
+ with:
+ python-version: 3.8
+
+ - name: Install requirements
+ run: |
+ pip install PyGithub
+ - name: Close stale issues
+ run: |
+ python utils/stale.py
diff --git a/diffusers/.github/workflows/trufflehog.yml b/diffusers/.github/workflows/trufflehog.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9cbbf6803724dacc4759b69d7002bb34831e5937
--- /dev/null
+++ b/diffusers/.github/workflows/trufflehog.yml
@@ -0,0 +1,15 @@
+on:
+ push:
+
+name: Secret Leaks
+
+jobs:
+ trufflehog:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Secret Scanning
+ uses: trufflesecurity/trufflehog@main
diff --git a/diffusers/.github/workflows/typos.yml b/diffusers/.github/workflows/typos.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fbd051b4da0dc6c1ec9e15a3a7bad07b122d81cd
--- /dev/null
+++ b/diffusers/.github/workflows/typos.yml
@@ -0,0 +1,14 @@
+name: Check typos
+
+on:
+ workflow_dispatch:
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: typos-action
+ uses: crate-ci/typos@v1.12.4
diff --git a/diffusers/.github/workflows/update_metadata.yml b/diffusers/.github/workflows/update_metadata.yml
new file mode 100644
index 0000000000000000000000000000000000000000..92aea0369ba855237b39d4dee9e8daee0c81010d
--- /dev/null
+++ b/diffusers/.github/workflows/update_metadata.yml
@@ -0,0 +1,30 @@
+name: Update Diffusers metadata
+
+on:
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ - update_diffusers_metadata*
+
+jobs:
+ update_metadata:
+ runs-on: ubuntu-22.04
+ defaults:
+ run:
+ shell: bash -l {0}
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Setup environment
+ run: |
+ pip install --upgrade pip
+ pip install datasets pandas
+ pip install .[torch]
+
+ - name: Update metadata
+ env:
+ HF_TOKEN: ${{ secrets.SAYAK_HF_TOKEN }}
+ run: |
+ python utils/update_metadata.py --commit_sha ${{ github.sha }}
diff --git a/diffusers/.github/workflows/upload_pr_documentation.yml b/diffusers/.github/workflows/upload_pr_documentation.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fc102df8103e48fb139a8bd47be05fc257d992c5
--- /dev/null
+++ b/diffusers/.github/workflows/upload_pr_documentation.yml
@@ -0,0 +1,16 @@
+name: Upload PR Documentation
+
+on:
+ workflow_run:
+ workflows: ["Build PR Documentation"]
+ types:
+ - completed
+
+jobs:
+ build:
+ uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
+ with:
+ package_name: diffusers
+ secrets:
+ hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
+ comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
diff --git a/diffusers/.gitignore b/diffusers/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..15617d5fdc745f57271db05b5fe7fd83f440c3a7
--- /dev/null
+++ b/diffusers/.gitignore
@@ -0,0 +1,178 @@
+# Initially taken from GitHub's Python gitignore file
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# tests and logs
+tests/fixtures/cached_*_text.txt
+logs/
+lightning_logs/
+lang_code_data/
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a Python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# vscode
+.vs
+.vscode
+
+# Pycharm
+.idea
+
+# TF code
+tensorflow_code
+
+# Models
+proc_data
+
+# examples
+runs
+/runs_old
+/wandb
+/examples/runs
+/examples/**/*.args
+/examples/rag/sweep
+
+# data
+/data
+serialization_dir
+
+# emacs
+*.*~
+debug.env
+
+# vim
+.*.swp
+
+# ctags
+tags
+
+# pre-commit
+.pre-commit*
+
+# .lock
+*.lock
+
+# DS_Store (MacOS)
+.DS_Store
+
+# RL pipelines may produce mp4 outputs
+*.mp4
+
+# dependencies
+/transformers
+
+# ruff
+.ruff_cache
+
+# wandb
+wandb
\ No newline at end of file
diff --git a/diffusers/CITATION.cff b/diffusers/CITATION.cff
new file mode 100644
index 0000000000000000000000000000000000000000..09fc6c744d06407fa6b4707af1afe3cc9529b4db
--- /dev/null
+++ b/diffusers/CITATION.cff
@@ -0,0 +1,52 @@
+cff-version: 1.2.0
+title: 'Diffusers: State-of-the-art diffusion models'
+message: >-
+ If you use this software, please cite it using the
+ metadata from this file.
+type: software
+authors:
+ - given-names: Patrick
+ family-names: von Platen
+ - given-names: Suraj
+ family-names: Patil
+ - given-names: Anton
+ family-names: Lozhkov
+ - given-names: Pedro
+ family-names: Cuenca
+ - given-names: Nathan
+ family-names: Lambert
+ - given-names: Kashif
+ family-names: Rasul
+ - given-names: Mishig
+ family-names: Davaadorj
+ - given-names: Dhruv
+ family-names: Nair
+ - given-names: Sayak
+ family-names: Paul
+ - given-names: Steven
+ family-names: Liu
+ - given-names: William
+ family-names: Berman
+ - given-names: Yiyi
+ family-names: Xu
+ - given-names: Thomas
+ family-names: Wolf
+repository-code: 'https://github.com/huggingface/diffusers'
+abstract: >-
+ Diffusers provides pretrained diffusion models across
+ multiple modalities, such as vision and audio, and serves
+ as a modular toolbox for inference and training of
+ diffusion models.
+keywords:
+ - deep-learning
+ - pytorch
+ - image-generation
+ - hacktoberfest
+ - diffusion
+ - text2image
+ - image2image
+ - score-based-generative-modeling
+ - stable-diffusion
+ - stable-diffusion-diffusers
+license: Apache-2.0
+version: 0.12.1
diff --git a/diffusers/CODE_OF_CONDUCT.md b/diffusers/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..2139079964fbd53692380985e60ef90e2fa05dad
--- /dev/null
+++ b/diffusers/CODE_OF_CONDUCT.md
@@ -0,0 +1,130 @@
+
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, caste, color, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+ overall Diffusers community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+ advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+ address, without their explicit permission
+* Spamming issues or PRs with links to projects unrelated to this library
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+feedback@huggingface.co.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.1, available at
+https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
diff --git a/diffusers/CONTRIBUTING.md b/diffusers/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..049d317599ad2292504acf3d139b332bc64cbe74
--- /dev/null
+++ b/diffusers/CONTRIBUTING.md
@@ -0,0 +1,506 @@
+
+
+# How to contribute to Diffusers 🧨
+
+We ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation –not just code– are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid and get involved if you're up for it!
+
+Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕.
+
+Whichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility.
+
+We enormously value feedback from the community, so please do not be afraid to speak up if you believe you have valuable feedback that can help improve the library - every message, comment, issue, and pull request (PR) is read and considered.
+
+## Overview
+
+You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to
+the core library.
+
+In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.
+
+* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).
+* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose).
+* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues).
+* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
+* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).
+* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).
+* 7. Contribute to the [examples](https://github.com/huggingface/diffusers/tree/main/examples).
+* 8. Fix a more difficult issue, marked by the "Good second issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22).
+* 9. Add a new pipeline, model, or scheduler, see ["New Pipeline/Model"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) and ["New scheduler"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) issues. For this contribution, please have a look at [Design Philosophy](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md).
+
+As said before, **all contributions are valuable to the community**.
+In the following, we will explain each contribution a bit more in detail.
+
+For all contributions 4-9, you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr).
+
+### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord
+
+Any question or comment related to the Diffusers library can be asked on the [discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/) or on [Discord](https://discord.gg/G7tWnz98XR). Such questions and comments include (but are not limited to):
+- Reports of training or inference experiments in an attempt to share knowledge
+- Presentation of personal projects
+- Questions about non-official training examples
+- Project proposals
+- General feedback
+- Paper summaries
+- Asking for help on personal projects that build on top of the Diffusers library
+- General questions
+- Ethical questions regarding diffusion models
+- ...
+
+Every question that is asked on the forum or on Discord actively encourages the community to publicly
+share knowledge and might very well help a beginner in the future who has the same question you're
+having. Please do pose any questions you might have.
+In the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from.
+
+**Please** keep in mind that the more effort you put into asking or answering a question, the higher
+the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
+In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
+
+**NOTE about channels**:
+[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that were posted some time ago.
+In addition, questions and answers posted in the forum can easily be linked to.
+In contrast, *Discord* has a chat-like format that invites fast back-and-forth communication.
+While it will most likely take less time for you to get an answer to your question on Discord, your
+question won't be visible anymore over time. Also, it's much harder to find information that was posted a while back on Discord. We therefore strongly recommend using the forum for high-quality questions and answers in an attempt to create long-lasting knowledge for the community. If discussions on Discord lead to very interesting answers and conclusions, we recommend posting the results on the forum to make the information more available for future readers.
+
+### 2. Opening new issues on the GitHub issues tab
+
+The 🧨 Diffusers library is robust and reliable thanks to the users who notify us of
+the problems they encounter. So thank you for reporting an issue.
+
+Remember, GitHub issues are reserved for technical questions directly related to the Diffusers library, bug reports, feature requests, or feedback on the library design.
+
+In a nutshell, this means that everything that is **not** related to the **code of the Diffusers library** (including the documentation) should **not** be asked on GitHub, but rather on either the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
+
+**Please consider the following guidelines when opening a new issue**:
+- Make sure you have searched whether your issue has already been asked before (use the search bar on GitHub under Issues).
+- Please never report a new issue on another (related) issue. If another issue is highly related, please
+open a new issue nevertheless and link to the related issue.
+- Make sure your issue is written in English. Please use one of the great, free online translation services, such as [DeepL](https://www.deepl.com/translator) to translate from your native language to English if you are not comfortable in English.
+- Check whether your issue might be solved by updating to the newest Diffusers version. Before posting your issue, please make sure that the version printed by `python -c "import diffusers; print(diffusers.__version__)"` matches or is newer than the latest Diffusers release (a quick way to check this is sketched right after this list).
+- Remember that the more effort you put into opening a new issue, the higher the quality of your answer will be and the better the overall quality of the Diffusers issues.
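+
+For illustration only, here is a minimal sketch of how you could compare your installed version against the newest release; it assumes the `requests` and `packaging` packages are available and queries the public PyPI JSON API (this snippet is not part of the Diffusers tooling):
+
+```python
+import requests
+from packaging import version
+
+import diffusers
+
+# Ask PyPI for the newest published diffusers release.
+latest = requests.get("https://pypi.org/pypi/diffusers/json", timeout=10).json()["info"]["version"]
+print(f"installed: {diffusers.__version__} | latest on PyPI: {latest}")
+
+if version.parse(diffusers.__version__) < version.parse(latest):
+    print("Your installation is outdated; consider `pip install -U diffusers` before filing the issue.")
+```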
+
+New issues usually include the following.
+
+#### 2.1. Reproducible, minimal bug reports
+
+A bug report should always have a reproducible code snippet and be as minimal and concise as possible.
+This means in more detail:
+- Narrow the bug down as much as you can: **do not just dump your whole code file**.
+- Format your code.
+- Do not include any external libraries beyond Diffusers and the libraries it depends on.
+- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.
+- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it.
+- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.
+- If a model and/or dataset is required to reproduce your issue, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.
+
+For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
+
+You can open a bug report [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml).
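+
+To make the guidance above concrete, here is a minimal sketch of what a reproducible snippet might look like. It uses the tiny test checkpoint `hf-internal-testing/tiny-stable-diffusion-pipe` purely as an example so the reproduction stays small and fast:
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+# A tiny test checkpoint keeps the download and runtime minimal.
+pipe = DiffusionPipeline.from_pretrained(
+    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+)
+pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+
+# A short prompt and very few steps are usually enough to trigger the bug.
+image = pipe("a photo of an astronaut", num_inference_steps=2).images[0]
+image.save("repro.png")
+```
+
+Together with the output of `diffusers-cli env`, a snippet like this is usually all a maintainer needs to start investigating.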
+
+#### 2.2. Feature requests
+
+A world-class feature request addresses the following points:
+
+1. Motivation first:
+* Is it related to a problem/frustration with the library? If so, please explain
+why. Providing a code snippet that demonstrates the problem is best.
+* Is it related to something you would need for a project? We'd love to hear
+about it!
+* Is it something you worked on and think could benefit the community?
+Awesome! Tell us what problem it solved for you.
+2. Write a *full paragraph* describing the feature;
+3. Provide a **code snippet** that demonstrates its future use;
+4. In case this is related to a paper, please attach a link;
+5. Attach any additional information (drawings, screenshots, etc.) you think may help.
+
+You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=).
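+
+As an illustration of point 3, a feature request can sketch the *desired* usage with clearly hypothetical code; the method below does not exist and is only an example of how to communicate a proposed API:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
+
+# Proposed feature (hypothetical, does not exist today): a single switch that
+# enables every applicable memory optimization for the loaded pipeline.
+pipe.enable_all_memory_optimizations()
+
+image = pipe("a watercolor painting of a lighthouse").images[0]
+```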
+
+#### 2.3 Feedback
+
+Feedback about the library design, and why it is good or bad, helps the core maintainers immensely in building a user-friendly library. To understand the reasoning behind the current design philosophy, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too strictly, hence restricting use cases, explain why and how it should be changed.
+If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions.
+
+You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
+
+#### 2.4 Technical questions
+
+Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide detail on
+why this part of the code is difficult to understand.
+
+You can open an issue about a technical question [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml).
+
+#### 2.5 Proposal to add a new model, scheduler, or pipeline
+
+If the diffusion model community released a new model, pipeline, or scheduler that you would like to see in the Diffusers library, please provide the following information:
+
+* Short description of the diffusion pipeline, model, or scheduler and link to the paper or public release.
+* Link to any of its open-source implementations.
+* Link to the model weights if they are available.
+
+If you are willing to contribute to the model yourself, let us know so we can best guide you. Also, don't forget
+to tag the original author of the component (model, scheduler, pipeline, etc.) by GitHub handle if you can find it.
+
+You can open a request for a model/pipeline/scheduler [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml).
+
+### 3. Answering issues on the GitHub issues tab
+
+Answering issues on GitHub might require some technical knowledge of Diffusers, but we encourage everybody to give it a try even if you are not 100% certain that your answer is correct.
+Some tips to give a high-quality answer to an issue:
+- Be as concise and minimal as possible.
+- Stay on topic. An answer to the issue should concern the issue and only the issue.
+- Provide links to code, papers, or other sources that prove or encourage your point.
+- Answer in code. If a simple code snippet is the answer to the issue or shows how the issue can be solved, please provide a fully reproducible code snippet.
+
+Also, many issues tend to be simply off-topic, duplicates of other issues, or irrelevant. It is of great
+help to the maintainers if you can answer such issues, encouraging the author of the issue to be
+more precise, providing a link to the duplicate issue, or redirecting them to [the forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
+
+If you have verified that a reported bug is real and requires a fix in the source code,
+please have a look at the next sections.
+
+For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section.
+
+### 4. Fixing a "Good first issue"
+
+*Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already
+explains how a potential solution should look so that it is easier to fix.
+If the issue hasn't been closed and you would like to try to fix this issue, you can just leave a message saying "I would like to try this issue." There are usually three scenarios:
+- a.) The issue description already proposes a fix. In this case and if the solution makes sense to you, you can open a PR or draft PR to fix it.
+- b.) The issue description does not propose a fix. In this case, you can ask what a proposed fix could look like and someone from the Diffusers team should answer shortly. If you have a good idea of how to fix it, feel free to directly open a PR.
+- c.) There is already an open PR to fix the issue, but the issue hasn't been closed yet. If the PR has gone stale, you can simply open a new PR and link to the stale PR. PRs often go stale if the original contributor who wanted to fix the issue suddenly cannot find the time anymore to proceed. This often happens in open-source and is very normal. In this case, the community will be very happy if you give it a new try and leverage the knowledge of the existing PR. If there is already a PR and it is active, you can help the author by giving suggestions, reviewing the PR or even asking whether you can contribute to the PR.
+
+
+### 5. Contribute to the documentation
+
+A good library **always** has good documentation! The official documentation is often one of the first points of contact for new users of the library, and therefore contributing to the documentation is a **highly
+valuable contribution**.
+
+Contributing to the documentation can take many forms:
+
+- Correcting spelling or grammatical errors.
+- Correcting incorrect docstring formatting. If you see that the official documentation is weirdly displayed or a link is broken, we are very happy if you take some time to correct it.
+- Correcting the shape or dimensions of a docstring input or output tensor.
+- Clarifying documentation that is hard to understand or incorrect.
+- Updating outdated code examples.
+- Translating the documentation to another language.
+
+Anything displayed on [the official Diffusers doc page](https://huggingface.co/docs/diffusers/index) is part of the official documentation and can be corrected or adjusted in the respective [documentation source](https://github.com/huggingface/diffusers/tree/main/docs/source).
+
+Please have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally.
+
+
+### 6. Contribute a community pipeline
+
+[Pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) are usually the first point of contact between the Diffusers library and the user.
+Pipelines are examples of how to use Diffusers [models](https://huggingface.co/docs/diffusers/api/models/overview) and [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview).
+We support two types of pipelines:
+
+- Official Pipelines
+- Community Pipelines
+
+Both official and community pipelines follow the same design and consist of the same type of components.
+
+Official pipelines are tested and maintained by the core maintainers of Diffusers. Their code
+resides in [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
+In contrast, community pipelines are contributed and maintained purely by the **community** and are **not** tested.
+They reside in [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and while they can be accessed via the [PyPI diffusers package](https://pypi.org/project/diffusers/), their code is not part of the PyPI distribution.
+
+The reason for the distinction is that the core maintainers of the Diffusers library cannot maintain and test all
+possible ways diffusion models can be used for inference, but some of them may be of interest to the community.
+Officially released diffusion pipelines,
+such as Stable Diffusion, are added to the core `src/diffusers/pipelines` package, which ensures
+high-quality maintenance, no backward-breaking code changes, and testing.
+More bleeding edge pipelines should be added as community pipelines. If usage for a community pipeline is high, the pipeline can be moved to the official pipelines upon request from the community. This is one of the ways we strive to be a community-driven library.
+
+To add a community pipeline, one should add a .py file to [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and adapt the [examples/community/README.md](https://github.com/huggingface/diffusers/tree/main/examples/community/README.md) to include an example of the new pipeline.
+
+An example can be seen [here](https://github.com/huggingface/diffusers/pull/2400).
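+
+For orientation, here is a minimal, hedged sketch of what a community pipeline file could look like. The class name, file name, and the one-step denoising logic are illustrative assumptions rather than an official template; only `DiffusionPipeline`, `register_modules`, and the UNet/scheduler calls reflect the existing Diffusers API.
+
+```python
+# examples/community/one_step_unet_pipeline.py (hypothetical file name)
+import torch
+
+from diffusers import DiffusionPipeline
+
+
+class OneStepUNetPipeline(DiffusionPipeline):
+    def __init__(self, unet, scheduler):
+        super().__init__()
+        # register_modules makes the components saveable/loadable via save_pretrained/from_pretrained
+        self.register_modules(unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(self, batch_size: int = 1):
+        sample_size = self.unet.config.sample_size
+        noise = torch.randn(
+            (batch_size, self.unet.config.in_channels, sample_size, sample_size), device=self.device
+        )
+        # Run a single denoising step purely for illustration
+        self.scheduler.set_timesteps(1)
+        t = self.scheduler.timesteps[0]
+        model_output = self.unet(noise, t).sample
+        return self.scheduler.step(model_output, t, noise).prev_sample
+```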
+
+Community pipeline PRs are only checked at a superficial level and ideally they should be maintained by their original authors.
+
+Contributing a community pipeline is a great way to understand how Diffusers models and schedulers work. Having contributed a community pipeline is usually the first stepping stone to contributing an official pipeline to the
+core package.
+
+### 7. Contribute to training examples
+
+Diffusers examples are a collection of training scripts that reside in [examples](https://github.com/huggingface/diffusers/tree/main/examples).
+
+We support two types of training examples:
+
+- Official training examples
+- Research training examples
+
+Research training examples are located in [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) whereas official training examples include all folders under [examples](https://github.com/huggingface/diffusers/tree/main/examples) except the `research_projects` and `community` folders.
+The official training examples are maintained by the Diffusers' core maintainers whereas the research training examples are maintained by the community.
+This is for the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: it is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
+If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
+
+Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. To make use of the
+training examples, the user needs to clone the repository:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+```
+
+as well as to install all additional dependencies required for training:
+
+```bash
+cd diffusers
+pip install -r examples/<your-example-folder>/requirements.txt
+```
+
+Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
+
+Training examples of the Diffusers library should adhere to the following philosophy:
+- All the code necessary to run the examples should be found in a single Python file.
+- One should be able to run the example from the command line with `python <your-example>.py --args`.
+- Examples should be kept simple and serve as **an example** on how to use Diffusers for training. The purpose of example scripts is **not** to create state-of-the-art diffusion models, but rather to reproduce known training schemes without adding too much custom logic. As a byproduct of this point, our examples also strive to serve as good educational materials.
+
+To contribute an example, it is highly recommended to look at already existing examples such as [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) to get an idea of what they should look like.
+We strongly advise contributors to make use of the [Accelerate library](https://github.com/huggingface/accelerate) as it's tightly integrated
+with Diffusers.
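+
+As a rough, hedged sketch of the expected layout, the skeleton below shows the single-file, argparse-plus-Accelerate structure described above. The file name, arguments, and the stand-in linear model are purely illustrative assumptions; a real example would build a diffusion model, a dataset, checkpointing, and logging instead.
+
+```python
+# train_my_example.py (hypothetical file name), run with: python train_my_example.py --num_train_steps 10
+import argparse
+
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Minimal training example skeleton.")
+    parser.add_argument("--learning_rate", type=float, default=1e-4)
+    parser.add_argument("--num_train_steps", type=int, default=10)
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    accelerator = Accelerator()
+
+    # Stand-in model and optimizer; a real example would instantiate a diffusion model and a DataLoader.
+    model = torch.nn.Linear(4, 4)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
+    model, optimizer = accelerator.prepare(model, optimizer)
+
+    for step in range(args.num_train_steps):
+        batch = torch.randn(8, 4, device=accelerator.device)
+        loss = F.mse_loss(model(batch), batch)
+        accelerator.backward(loss)
+        optimizer.step()
+        optimizer.zero_grad()
+        if accelerator.is_main_process:
+            print(f"step {step}: loss {loss.item():.4f}")
+
+
+if __name__ == "__main__":
+    main()
+```
+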
+Once an example script works, please make sure to add a comprehensive `README.md` that states how to use the example exactly. This README should include:
+- An example command showing how to run the example script, e.g. as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch).
+- A link to some training results (logs, models, etc.) that show what the user can expect, e.g. as shown [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
+- If you are adding a non-official/research training example, **please don't forget** to add a sentence stating that you are maintaining this training example, which includes your Git handle, as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).
+
+If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples.
+
+### 8. Fixing a "Good second issue"
+
+*Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are
+usually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
+The issue description usually gives less guidance on how to fix the issue and requires
+a decent understanding of the library by the interested contributor.
+If you are interested in tackling a good second issue, feel free to open a PR to fix it and link the PR to the issue. If you see that a PR has already been opened for this issue but did not get merged, have a look to understand why it wasn't merged and try to open an improved PR.
+Good second issues are usually more difficult to get merged compared to good first issues, so don't hesitate to ask for help from the core maintainers. If your PR is almost finished the core maintainers can also jump into your PR and commit to it in order to get it merged.
+
+### 9. Adding pipelines, models, schedulers
+
+Pipelines, models, and schedulers are the most important pieces of the Diffusers library.
+They provide easy access to state-of-the-art diffusion technologies and thus allow the community to
+build powerful generative AI applications.
+
+By adding a new model, pipeline, or scheduler you might enable a new powerful use case for any of the user interfaces relying on Diffusers which can be of immense value for the whole generative AI ecosystem.
+
+Diffusers has a couple of open feature requests for all three components; feel free to look through them
+if you don't yet know which specific component you would like to add:
+- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
+- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
+
+Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) a read to better understand the design of any of the three components. Please be aware that
+we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
+as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
+open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
+pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
+
+Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
+original author directly on the PR so that they can follow the progress and potentially help with questions.
+
+If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
+
+## How to write a good issue
+
+**The better your issue is written, the higher the chances that it will be quickly resolved.**
+
+1. Make sure that you've used the correct template for your issue. You can pick between *Bug Report*, *Feature Request*, *Feedback about API Design*, *New model/pipeline/scheduler addition*, *Forum*, or a blank issue. Make sure to pick the correct one when opening [a new issue](https://github.com/huggingface/diffusers/issues/new/choose).
+2. **Be precise**: Give your issue a fitting title. Try to formulate your issue description as simply as possible. The more precise you are when submitting an issue, the less time it takes to understand the issue and potentially solve it. Make sure to open an issue for one issue only and not for multiple issues. If you found multiple issues, simply open multiple issues. If your issue is a bug, try to be as precise as possible about what bug it is - you should not just write "Error in diffusers".
+3. **Reproducibility**: No reproducible code snippet == no solution. If you encounter a bug, maintainers **have to be able to reproduce** it. Make sure that you include a code snippet that can be copy-pasted into a Python interpreter to reproduce the issue. Make sure that your code snippet works, *i.e.* that there are no missing imports or missing links to images, ... Your issue should contain an error message **and** a code snippet that can be copy-pasted without any changes to reproduce the exact same error message. If your issue is using local model weights or local data that cannot be accessed by the reader, the issue cannot be solved. If you cannot share your data or model, try to make a dummy model or dummy data (see the minimal sketch after this list).
+4. **Minimalistic**: Try to help the reader as much as you can to understand the issue as quickly as possible by staying as concise as possible. Remove all code / all information that is irrelevant to the issue. If you have found a bug, try to create the easiest code example you can to demonstrate your issue, do not just dump your whole workflow into the issue as soon as you have found a bug. E.g., if you train a model and get an error at some point during the training, you should first try to understand what part of the training code is responsible for the error and try to reproduce it with a couple of lines. Try to use dummy data instead of full datasets.
+5. Add links. If you are referring to a certain naming, method, or model make sure to provide a link so that the reader can better understand what you mean. If you are referring to a specific PR or issue, make sure to link it to your issue. Do not assume that the reader knows what you are talking about. The more links you add to your issue the better.
+6. Formatting. Make sure to nicely format your issue by formatting code into Python code syntax, and error messages into normal code syntax. See the [official GitHub formatting docs](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) for more information.
+7. Think of your issue not as a ticket to be solved, but rather as a beautiful entry to a well-written encyclopedia. Every added issue is a contribution to publicly available knowledge. By adding a nicely written issue you not only make it easier for maintainers to solve your issue, but you are helping the whole community to better understand a certain aspect of the library.
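+
+To make point 3 concrete, here is a minimal, hedged sketch of what a copy-pasteable reproduction could look like. The tiny checkpoint name is an assumption used only to keep the download small; substitute whatever publicly accessible model actually triggers your bug.
+
+```python
+import torch
+
+from diffusers import DiffusionPipeline
+
+# A small, publicly hosted checkpoint (assumed name) keeps the reproduction fast to download and run.
+pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
+pipe.to("cpu")
+
+# Fix the seed and keep the generation tiny so maintainers can reproduce the exact same output.
+generator = torch.Generator("cpu").manual_seed(0)
+image = pipe("a red apple", num_inference_steps=2, generator=generator).images[0]
+print(image.size)
+```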
+
+## How to write a good PR
+
+1. Be a chameleon. Understand existing design patterns and syntax and make sure your code additions flow seamlessly into the existing code base. Pull requests that significantly diverge from existing design patterns or user interfaces will not be merged.
+2. Be laser focused. A pull request should solve one problem and one problem only. Make sure to not fall into the trap of "also fixing another problem while we're adding it". It is much more difficult to review pull requests that solve multiple, unrelated problems at once.
+3. If helpful, try to add a code snippet that displays an example of how your addition can be used.
+4. The title of your pull request should be a summary of its contribution.
+5. If your pull request addresses an issue, please mention the issue number in
+the pull request description to make sure they are linked (and people
+consulting the issue know you are working on it);
+6. To indicate a work in progress please prefix the title with `[WIP]`. These
+are useful to avoid duplicated work, and to differentiate it from PRs ready
+to be merged;
+7. Try to formulate and format your text as explained in [How to write a good issue](#how-to-write-a-good-issue).
+8. Make sure existing tests pass;
+9. Add high-coverage tests. No quality testing = no merge.
+- If you are adding new `@slow` tests, make sure they pass using
+`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
+CircleCI does not run the slow tests, but GitHub Actions does every night!
+10. All public methods must have informative docstrings that work nicely with markdown. See [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) for an example.
+11. Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
+[`hf-internal-testing`](https://huggingface.co/hf-internal-testing) or [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images) to place these files.
+If your contribution is external, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
+to this dataset.
+
+## How to open a PR
+
+Before writing code, we strongly advise you to search through the existing PRs or
+issues to make sure that nobody is already working on the same thing. If you are
+unsure, it is always a good idea to open an issue to get some feedback.
+
+You will need basic `git` proficiency to be able to contribute to
+🧨 Diffusers. `git` is not the easiest tool to use but it has the greatest
+manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
+Git](https://git-scm.com/book/en/v2) is a very good reference.
+
+Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/42f25d601a910dceadaee6c44345896b4cfa9928/setup.py#L270)):
+
+1. Fork the [repository](https://github.com/huggingface/diffusers) by
+clicking on the 'Fork' button on the repository's page. This creates a copy of the code
+under your GitHub user account.
+
+2. Clone your fork to your local disk, and add the base repository as a remote:
+
+ ```bash
+ $ git clone git@github.com:<your GitHub handle>/diffusers.git
+ $ cd diffusers
+ $ git remote add upstream https://github.com/huggingface/diffusers.git
+ ```
+
+3. Create a new branch to hold your development changes:
+
+ ```bash
+ $ git checkout -b a-descriptive-name-for-my-changes
+ ```
+
+**Do not** work on the `main` branch.
+
+4. Set up a development environment by running the following command in a virtual environment:
+
+ ```bash
+ $ pip install -e ".[dev]"
+ ```
+
+If you have already cloned the repo, you might need to `git pull` to get the most recent changes in the
+library.
+
+5. Develop the features on your branch.
+
+As you work on the features, you should make sure that the test suite
+passes. You should run the tests impacted by your changes like this:
+
+ ```bash
+ $ pytest tests/<TEST_TO_RUN>.py
+ ```
+
+Before you run the tests, please make sure you install the dependencies required for testing. You can do so
+with this command:
+
+ ```bash
+ $ pip install -e ".[test]"
+ ```
+
+You can also run the full test suite, but it takes
+a beefy machine to produce a result in a decent amount of time now that
+Diffusers has grown a lot. Here is the command for it:
+
+ ```bash
+ $ make test
+ ```
+
+🧨 Diffusers relies on `ruff` and `isort` to format its source code
+consistently. After you make changes, apply automatic style corrections and code verifications
+that can't be automated in one go with:
+
+ ```bash
+ $ make style
+ ```
+
+🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
+control runs in CI; however, you can also run the same checks with:
+
+ ```bash
+ $ make quality
+ ```
+
+Once you're happy with your changes, add changed files using `git add` and
+make a commit with `git commit` to record your changes locally:
+
+ ```bash
+ $ git add modified_file.py
+ $ git commit -m "A descriptive message about your changes."
+ ```
+
+It is a good idea to sync your copy of the code with the original
+repository regularly. This way you can quickly account for changes:
+
+ ```bash
+ $ git pull upstream main
+ ```
+
+Push the changes to your account using:
+
+ ```bash
+ $ git push -u origin a-descriptive-name-for-my-changes
+ ```
+
+6. Once you are satisfied, go to the
+webpage of your fork on GitHub. Click on 'Pull request' to send your changes
+to the project maintainers for review.
+
+7. It's ok if maintainers ask you for changes. It happens to core contributors
+too! So that everyone can see the changes in the pull request, work in your local
+branch and push the changes to your fork. They will automatically appear in
+the pull request.
+
+### Tests
+
+An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
+the [tests folder](https://github.com/huggingface/diffusers/tree/main/tests).
+
+We like `pytest` and `pytest-xdist` because running tests in parallel is faster. From the root of the
+repository, here's how to run tests with `pytest` for the library:
+
+```bash
+$ python -m pytest -n auto --dist=loadfile -s -v ./tests/
+```
+
+In fact, that's how `make test` is implemented!
+
+You can specify a smaller set of tests in order to test only the feature
+you're working on.
+
+By default, slow tests are skipped. Set the `RUN_SLOW` environment variable to
+`yes` to run them. This will download many gigabytes of models — make sure you
+have enough disk space and a good Internet connection, or a lot of patience!
+
+```bash
+$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/
+```
+
+`unittest` is fully supported; here's how to run tests with it:
+
+```bash
+$ python -m unittest discover -s tests -t . -v
+$ python -m unittest discover -s examples -t examples -v
+```
+
+### Syncing forked main with upstream (HuggingFace) main
+
+When syncing the main branch of a forked repository, please follow these steps to avoid pinging the upstream repository,
+which adds reference notes to each upstream PR and sends unnecessary notifications to the developers involved in those PRs:
+1. When possible, avoid syncing with the upstream using a branch and PR on the forked repository. Instead, merge directly into the forked main.
+2. If a PR is absolutely necessary, use the following steps after checking out your branch:
+```bash
+$ git checkout -b your-branch-for-syncing
+$ git pull --squash --no-commit upstream main
+$ git commit -m '<your message without GitHub references>'
+$ git push --set-upstream origin your-branch-for-syncing
+```
+
+### Style guide
+
+For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
diff --git a/diffusers/LICENSE b/diffusers/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/diffusers/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/diffusers/MANIFEST.in b/diffusers/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..b22fe1a28a1ef881fdb36af3c30b14c0a5d10aa5
--- /dev/null
+++ b/diffusers/MANIFEST.in
@@ -0,0 +1,2 @@
+include LICENSE
+include src/diffusers/utils/model_card_template.md
diff --git a/diffusers/Makefile b/diffusers/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..9af2e8b1a5c9993411ffa06e7e48f9cfec3bd164
--- /dev/null
+++ b/diffusers/Makefile
@@ -0,0 +1,96 @@
+.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
+
+# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
+export PYTHONPATH = src
+
+check_dirs := examples scripts src tests utils benchmarks
+
+modified_only_fixup:
+ $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
+ @if test -n "$(modified_py_files)"; then \
+ echo "Checking/fixing $(modified_py_files)"; \
+ ruff check $(modified_py_files) --fix; \
+ ruff format $(modified_py_files);\
+ else \
+ echo "No library .py files were modified"; \
+ fi
+
+# Update src/diffusers/dependency_versions_table.py
+
+deps_table_update:
+ @python setup.py deps_table_update
+
+deps_table_check_updated:
+ @md5sum src/diffusers/dependency_versions_table.py > md5sum.saved
+ @python setup.py deps_table_update
+ @md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1)
+ @rm md5sum.saved
+
+# autogenerating code
+
+autogenerate_code: deps_table_update
+
+# Check that the repo is in a good state
+
+repo-consistency:
+ python utils/check_dummies.py
+ python utils/check_repo.py
+ python utils/check_inits.py
+
+# this target runs checks on all files
+
+quality:
+ ruff check $(check_dirs) setup.py
+ ruff format --check $(check_dirs) setup.py
+ doc-builder style src/diffusers docs/source --max_len 119 --check_only
+ python utils/check_doc_toc.py
+
+# Format source code automatically and check if there are any problems left that need manual fixing
+
+extra_style_checks:
+ python utils/custom_init_isort.py
+ python utils/check_doc_toc.py --fix_and_overwrite
+
+# this target runs checks on all files and potentially modifies some of them
+
+style:
+ ruff check $(check_dirs) setup.py --fix
+ ruff format $(check_dirs) setup.py
+ doc-builder style src/diffusers docs/source --max_len 119
+ ${MAKE} autogenerate_code
+ ${MAKE} extra_style_checks
+
+# Super fast fix and check target that only works on relevant modified files since the branch was made
+
+fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
+
+# Make marked copies of snippets of codes conform to the original
+
+fix-copies:
+ python utils/check_copies.py --fix_and_overwrite
+ python utils/check_dummies.py --fix_and_overwrite
+
+# Run tests for the library
+
+test:
+ python -m pytest -n auto --dist=loadfile -s -v ./tests/
+
+# Run tests for examples
+
+test-examples:
+ python -m pytest -n auto --dist=loadfile -s -v ./examples/
+
+
+# Release stuff
+
+pre-release:
+ python utils/release.py
+
+pre-patch:
+ python utils/release.py --patch
+
+post-release:
+ python utils/release.py --post_release
+
+post-patch:
+ python utils/release.py --post_release --patch
diff --git a/diffusers/PHILOSOPHY.md b/diffusers/PHILOSOPHY.md
new file mode 100644
index 0000000000000000000000000000000000000000..c646c61ec429020032cd51f0d17ac007b17edc48
--- /dev/null
+++ b/diffusers/PHILOSOPHY.md
@@ -0,0 +1,110 @@
+
+
+# Philosophy
+
+🧨 Diffusers provides **state-of-the-art** pretrained diffusion models across multiple modalities.
+Its purpose is to serve as a **modular toolbox** for both inference and training.
+
+We aim to build a library that stands the test of time and therefore take API design very seriously.
+
+In a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefore, most of our design choices are based on [PyTorch's Design Principles](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy). Let's go over the most important ones:
+
+## Usability over Performance
+
+- While Diffusers has many built-in performance-enhancing features (see [Memory and Speed](https://huggingface.co/docs/diffusers/optimization/fp16)), models are always loaded with the highest precision and lowest optimization. Therefore, by default diffusion pipelines are always instantiated on CPU with float32 precision if not otherwise defined by the user (see the snippet after this list). This ensures usability across different platforms and accelerators and means that no complex installations are required to run the library.
+- Diffusers aims to be a **light-weight** package and therefore has very few required dependencies, but many soft dependencies that can improve performance (such as `accelerate`, `safetensors`, `onnx`, etc...). We strive to keep the library as lightweight as possible so that it can be added without much concern as a dependency on other packages.
+- Diffusers prefers simple, self-explanatory code over condensed, magic code. This means that shorthand code syntax such as lambda functions and advanced PyTorch operators is often not desired.
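+
+As a brief, hedged illustration of the first point above, the default load keeps full precision on CPU, and any performance gain is an explicit opt-in by the user. The checkpoint name is only an example.
+
+```python
+import torch
+
+from diffusers import DiffusionPipeline
+
+# Default load: float32 weights on CPU -- runs anywhere without extra setup.
+pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
+
+# Performance is an explicit opt-in, e.g. half precision on a CUDA GPU.
+pipe = DiffusionPipeline.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+)
+pipe.to("cuda")
+```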
+
+## Simple over easy
+
+As PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library:
+- We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management.
+- Raising concise error messages is preferred to silently correct erroneous input. Diffusers aims at teaching the user, rather than making the library as easy to use as possible.
+- Complex model vs. scheduler logic is exposed instead of magically handled inside. Schedulers/Samplers are separated from diffusion models with minimal dependencies on each other. This forces the user to write the unrolled denoising loop. However, the separation allows for easier debugging and gives the user more control over adapting the denoising process or switching out diffusion models or schedulers.
+- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the UNet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. DreamBooth or Textual Inversion training
+is very simple thanks to Diffusers' ability to separate single components of the diffusion pipeline.
+
+## Tweakable, contributor-friendly over abstraction
+
+For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself).
+In short, just like Transformers does for modeling files, Diffusers prefers to keep an extremely low level of abstraction and very self-contained code for pipelines and schedulers.
+Functions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable.
+**However**, this design has proven to be extremely successful for Transformers and makes a lot of sense for community-driven, open-source machine learning libraries because:
+- Machine Learning is an extremely fast-moving field in which paradigms, model architectures, and algorithms are changing rapidly, which therefore makes it very difficult to define long-lasting code abstractions.
+- Machine Learning practitioners like to be able to quickly tweak existing code for ideation and research and therefore prefer self-contained code over code that contains many abstractions.
+- Open-source libraries rely on community contributions and therefore must build a library that is easy to contribute to. The more abstract the code, the more dependencies, the harder to read, and the harder to contribute to. Contributors simply stop contributing to very abstract libraries out of fear of breaking vital functionality. If contributing to a library cannot break other fundamental code, not only is it more inviting for potential new contributors, but it is also easier to review and contribute to multiple parts in parallel.
+
+At Hugging Face, we call this design the **single-file policy** which means that almost all of the code of a certain class should be written in a single, self-contained file. To read more about the philosophy, you can have a look
+at [this blog post](https://huggingface.co/blog/transformers-design-philosophy).
+
+In Diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such
+as [DDPM](https://huggingface.co/docs/diffusers/api/pipelines/ddpm), [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines), [unCLIP (DALL·E 2)](https://huggingface.co/docs/diffusers/api/pipelines/unclip) and [Imagen](https://imagen.research.google/) all rely on the same diffusion model, the [UNet](https://huggingface.co/docs/diffusers/api/models/unet2d-cond).
+
+Great, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗.
+We try to apply these design principles consistently across the library. Nevertheless, there are some minor exceptions to the philosophy or some unlucky design choices. If you have feedback regarding the design, we would ❤️ to hear it [directly on GitHub](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
+
+## Design Philosophy in Details
+
+Now, let's look a bit into the nitty-gritty details of the design philosophy. Diffusers essentially consists of three major classes: [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines), [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models), and [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
+Let's walk through more detailed design decisions for each class.
+
+### Pipelines
+
+Pipelines are designed to be easy to use (therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference.
+
+The following design principles are followed:
+- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as it’s done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
+- Pipelines all inherit from [`DiffusionPipeline`].
+- Every pipeline consists of different model and scheduler components that are documented in the [`model_index.json` file](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json), are accessible under the same name as attributes of the pipeline, and can be shared between pipelines with the [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) function (see the sketch after this list).
+- Every pipeline should be loadable via the [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) function.
+- Pipelines should be used **only** for inference.
+- Pipelines should be very readable, self-explanatory, and easy to tweak.
+- Pipelines should be designed to build on top of each other and be easy to integrate into higher-level APIs.
+- Pipelines are **not** intended to be feature-complete user interfaces. For feature-complete user interfaces one should rather have a look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner).
+- Every pipeline should have one and only one way to run it via a `__call__` method. The naming of the `__call__` arguments should be shared across all pipelines.
+- Pipelines should be named after the task they are intended to solve.
+- In almost all cases, novel diffusion pipelines shall be implemented in a new pipeline folder/file.
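+
+As a short, hedged sketch of the component-sharing point above, the snippet below reuses the components of a text-to-image pipeline to build an image-to-image pipeline without loading the weights twice. The checkpoint name is an example; any Stable Diffusion text-to-image checkpoint should work the same way.
+
+```python
+from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
+
+text2img = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
+# `components` exposes the loaded model and scheduler objects so they can be shared across pipelines.
+img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
+```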
+
+### Models
+
+Models are designed as configurable toolboxes that are natural extensions of [PyTorch's Module class](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). They only partly follow the **single-file policy**.
+
+The following design principles are followed:
+- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
+- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its own file, e.g. [`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py), [`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py), etc...
+- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
+- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
+- Models all inherit from `ModelMixin` and `ConfigMixin`.
+- Models can be optimized for performance when the optimization doesn't demand major code changes, keeps backward compatibility, and gives a significant memory or compute gain.
+- Models should by default have the highest precision and lowest performance setting.
+- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.
+- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.
+- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
+readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
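+
+As a hedged illustration of the "configurable toolbox" idea described above, a model class can be loaded and called on its own, outside of any pipeline. The checkpoint name and the `subfolder="unet"` layout follow the standard Stable Diffusion repository structure and are assumptions here.
+
+```python
+import torch
+
+from diffusers import UNet2DConditionModel
+
+unet = UNet2DConditionModel.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
+)
+
+# Like any torch.nn.Module, the model is called directly with explicit inputs.
+sample = torch.randn(1, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size)
+timestep = torch.tensor(10)
+encoder_hidden_states = torch.randn(1, 77, unet.config.cross_attention_dim)
+with torch.no_grad():
+    noise_pred = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample
+print(noise_pred.shape)
+```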
+
+### Schedulers
+
+Schedulers are responsible for guiding the denoising process for inference as well as for defining a noise schedule for training. They are designed as individual classes with loadable configuration files and strongly follow the **single-file policy**.
+
+The following design principles are followed:
+- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
+- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
+- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).
+- If schedulers share similar functionalities, we can make use of the `# Copied from` mechanism.
+- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.
+- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](./docs/source/en/using-diffusers/schedulers.md) and sketched at the end of this section.
+- Every scheduler has to have a `set_num_inference_steps`, and a `step` function. `set_num_inference_steps(...)` has to be called before every denoising process, *i.e.* before `step(...)` is called.
+- Every scheduler exposes the timesteps to be "looped over" via a `timesteps` attribute, which is an array of timesteps the model will be called upon.
+- The `step(...)` function takes a predicted model output and the "current" sample (`x_t`) and returns the "previous", slightly more denoised sample (`x_t-1`).
+- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a "black box".
+- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
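+
+As a small, hedged sketch of the scheduler-swapping point above, the snippet below replaces a pipeline's default scheduler via `from_config`. `DPMSolverMultistepScheduler` is just one example of a compatible scheduler, and the checkpoint name is an assumption.
+
+```python
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
+# Build the new scheduler from the old scheduler's config so all shared settings carry over.
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+```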
diff --git a/diffusers/README.md b/diffusers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b99ca828e4d0dcf44a6954b9ae686435fc3127e9
--- /dev/null
+++ b/diffusers/README.md
@@ -0,0 +1,239 @@
+# 🤗 Diffusers
+
+🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both. Our library is designed with a focus on [usability over performance](https://huggingface.co/docs/diffusers/conceptual/philosophy#usability-over-performance), [simple over easy](https://huggingface.co/docs/diffusers/conceptual/philosophy#simple-over-easy), and [customizability over abstractions](https://huggingface.co/docs/diffusers/conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
+
+🤗 Diffusers offers three core components:
+
+- State-of-the-art [diffusion pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) that can be run in inference with just a few lines of code.
+- Interchangeable noise [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview) for different diffusion speeds and output quality.
+- Pretrained [models](https://huggingface.co/docs/diffusers/api/models/overview) that can be used as building blocks, and combined with schedulers, for creating your own end-to-end diffusion systems.
+
+## Installation
+
+We recommend installing 🤗 Diffusers in a virtual environment from PyPI or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/) and [Flax](https://flax.readthedocs.io/en/latest/#installation), please refer to their official documentation.
+
+### PyTorch
+
+With `pip` (official package):
+
+```bash
+pip install --upgrade diffusers[torch]
+```
+
+With `conda` (maintained by the community):
+
+```sh
+conda install -c conda-forge diffusers
+```
+
+### Flax
+
+With `pip` (official package):
+
+```bash
+pip install --upgrade diffusers[flax]
+```
+
+### Apple Silicon (M1/M2) support
+
+Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggingface.co/docs/diffusers/optimization/mps) guide.
+
+## Quickstart
+
+Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 30,000+ checkpoints):
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipeline.to("cuda")
+pipeline("An image of a squirrel in Picasso style").images[0]
+```
+
+You can also dig into the models and schedulers toolbox to build your own diffusion system:
+
+```python
+from diffusers import DDPMScheduler, UNet2DModel
+from PIL import Image
+import torch
+
+scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
+model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
+scheduler.set_timesteps(50)
+
+sample_size = model.config.sample_size
+noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
+input = noise
+
+for t in scheduler.timesteps:
+ with torch.no_grad():
+ noisy_residual = model(input, t).sample
+ prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
+ input = prev_noisy_sample
+
+image = (input / 2 + 0.5).clamp(0, 1)
+image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+image = Image.fromarray((image * 255).round().astype("uint8"))
+image
+```
+
+Check out the [Quickstart](https://huggingface.co/docs/diffusers/quicktour) to launch your diffusion journey today!
+
+## How to navigate the documentation
+
+| **Documentation** | **What can I learn?** |
+|---------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [Tutorial](https://huggingface.co/docs/diffusers/tutorials/tutorial_overview) | A basic crash course for learning how to use the library's most important features like using models and schedulers to build your own diffusion system, and training your own diffusion model. |
+| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading_overview) | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. |
+| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/pipeline_overview) | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. |
+| [Optimization](https://huggingface.co/docs/diffusers/optimization/opt_overview) | Guides for how to optimize your diffusion model to run faster and consume less memory. |
+| [Training](https://huggingface.co/docs/diffusers/training/overview) | Guides for how to train a diffusion model for different tasks with different training techniques. |
+
+## Contribution
+
+We ❤️ contributions from the open-source community!
+If you want to contribute to this library, please check out our [Contribution guide](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md).
+You can look out for [issues](https://github.com/huggingface/diffusers/issues) you'd like to tackle to contribute to the library.
+- See [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) for general opportunities to contribute
+- See [New model/pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) to contribute exciting new diffusion models / diffusion pipelines
+- See [New scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) to contribute a new scheduler
+
+Also, say 👋 in our public Discord channel. We discuss the hottest trends about diffusion models, help each other with contributions and personal projects, or just hang out ☕.
+
+
+## Popular Tasks & Pipelines
+
+## Popular libraries using 🧨 Diffusers
+
+- https://github.com/microsoft/TaskMatrix
+- https://github.com/invoke-ai/InvokeAI
+- https://github.com/InstantID/InstantID
+- https://github.com/apple/ml-stable-diffusion
+- https://github.com/Sanster/lama-cleaner
+- https://github.com/IDEA-Research/Grounded-Segment-Anything
+- https://github.com/ashawkey/stable-dreamfusion
+- https://github.com/deep-floyd/IF
+- https://github.com/bentoml/BentoML
+- https://github.com/bmaltais/kohya_ss
+- +14,000 other amazing GitHub repositories 💪
+
+Thank you for using us ❤️.
+
+## Credits
+
+This library concretizes previous work by many different authors and would not have been possible without their great research and implementations. We'd like to thank, in particular, the following implementations which have helped us in our development and without which the API could not have been as polished today:
+
+- @CompVis' latent diffusion models library, available [here](https://github.com/CompVis/latent-diffusion)
+- @hojonathanho original DDPM implementation, available [here](https://github.com/hojonathanho/diffusion) as well as the extremely useful translation into PyTorch by @pesser, available [here](https://github.com/pesser/pytorch_diffusion)
+- @ermongroup's DDIM implementation, available [here](https://github.com/ermongroup/ddim)
+- @yang-song's Score-VE and Score-VP implementations, available [here](https://github.com/yang-song/score_sde_pytorch)
+
+We also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available [here](https://github.com/heejkoo/Awesome-Diffusion-Models) as well as @crowsonkb and @rromb for useful discussions and insights.
+
+## Citation
+
+```bibtex
+@misc{von-platen-etal-2022-diffusers,
+ author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Dhruv Nair and Sayak Paul and William Berman and Yiyi Xu and Steven Liu and Thomas Wolf},
+ title = {Diffusers: State-of-the-art diffusion models},
+ year = {2022},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/huggingface/diffusers}}
+}
+```
diff --git a/diffusers/_typos.toml b/diffusers/_typos.toml
new file mode 100644
index 0000000000000000000000000000000000000000..551099f981e7885fbda9ed28e297bace0e13407b
--- /dev/null
+++ b/diffusers/_typos.toml
@@ -0,0 +1,13 @@
+# Files for typos
+# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
+
+[default.extend-identifiers]
+
+[default.extend-words]
+NIN="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
+nd="np" # nd may be np (numpy)
+parms="parms" # parms is used in scripts/convert_original_stable_diffusion_to_diffusers.py
+
+
+[files]
+extend-exclude = ["_typos.toml"]
diff --git a/diffusers/benchmarks/base_classes.py b/diffusers/benchmarks/base_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..45bf65c93c93a3a77d5b2e4b321b19a4cf9cd63f
--- /dev/null
+++ b/diffusers/benchmarks/base_classes.py
@@ -0,0 +1,346 @@
+import os
+import sys
+
+import torch
+
+from diffusers import (
+ AutoPipelineForImage2Image,
+ AutoPipelineForInpainting,
+ AutoPipelineForText2Image,
+ ControlNetModel,
+ LCMScheduler,
+ StableDiffusionAdapterPipeline,
+ StableDiffusionControlNetPipeline,
+ StableDiffusionXLAdapterPipeline,
+ StableDiffusionXLControlNetPipeline,
+ T2IAdapter,
+ WuerstchenCombinedPipeline,
+)
+from diffusers.utils import load_image
+
+
+sys.path.append(".")
+
+from utils import ( # noqa: E402
+ BASE_PATH,
+ PROMPT,
+ BenchmarkInfo,
+ benchmark_fn,
+ bytes_to_giga_bytes,
+ flush,
+ generate_csv_dict,
+ write_to_csv,
+)
+
+
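+# Target size (width, height) to which conditioning images are resized for each checkpoint.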
+RESOLUTION_MAPPING = {
+ "Lykon/DreamShaper": (512, 512),
+ "lllyasviel/sd-controlnet-canny": (512, 512),
+ "diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
+ "TencentARC/t2iadapter_canny_sd14v1": (512, 512),
+ "TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024),
+ "stabilityai/stable-diffusion-2-1": (768, 768),
+ "stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
+ "stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
+ "stabilityai/sdxl-turbo": (512, 512),
+}
+
+
+class BaseBenchmark:
+ pipeline_class = None
+
+ def __init__(self, args):
+ super().__init__()
+
+ def run_inference(self, args):
+ raise NotImplementedError
+
+ def benchmark(self, args):
+ raise NotImplementedError
+
+ def get_result_filepath(self, args):
+ pipeline_class_name = str(self.pipe.__class__.__name__)
+ name = (
+ args.ckpt.replace("/", "_")
+ + "_"
+ + pipeline_class_name
+ + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
+ )
+ filepath = os.path.join(BASE_PATH, name)
+ return filepath
+
+
+class TextToImageBenchmark(BaseBenchmark):
+ pipeline_class = AutoPipelineForText2Image
+
+ def __init__(self, args):
+ pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
+ pipe = pipe.to("cuda")
+
+ if args.run_compile:
+ if not isinstance(pipe, WuerstchenCombinedPipeline):
+ pipe.unet.to(memory_format=torch.channels_last)
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+ if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None:
+ pipe.movq.to(memory_format=torch.channels_last)
+ pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True)
+ else:
+ print("Run torch compile")
+ pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True)
+ pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True)
+
+ pipe.set_progress_bar_config(disable=True)
+ self.pipe = pipe
+
+ def run_inference(self, pipe, args):
+ _ = pipe(
+ prompt=PROMPT,
+ num_inference_steps=args.num_inference_steps,
+ num_images_per_prompt=args.batch_size,
+ )
+
+ def benchmark(self, args):
+ flush()
+
+ print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
+
+ time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
+ memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
+ benchmark_info = BenchmarkInfo(time=time, memory=memory)
+
+ pipeline_class_name = str(self.pipe.__class__.__name__)
+ flush()
+ csv_dict = generate_csv_dict(
+ pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info
+ )
+ filepath = self.get_result_filepath(args)
+ write_to_csv(filepath, csv_dict)
+ print(f"Logs written to: {filepath}")
+ flush()
+
+
+class TurboTextToImageBenchmark(TextToImageBenchmark):
+ def __init__(self, args):
+ super().__init__(args)
+
+ def run_inference(self, pipe, args):
+ _ = pipe(
+ prompt=PROMPT,
+ num_inference_steps=args.num_inference_steps,
+ num_images_per_prompt=args.batch_size,
+ guidance_scale=0.0,
+ )
+
+
+class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
+ lora_id = "latent-consistency/lcm-lora-sdxl"
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.pipe.load_lora_weights(self.lora_id)
+ self.pipe.fuse_lora()
+ self.pipe.unload_lora_weights()
+ self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
+
+ def get_result_filepath(self, args):
+ pipeline_class_name = str(self.pipe.__class__.__name__)
+ name = (
+ self.lora_id.replace("/", "_")
+ + "_"
+ + pipeline_class_name
+ + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
+ )
+ filepath = os.path.join(BASE_PATH, name)
+ return filepath
+
+ def run_inference(self, pipe, args):
+ _ = pipe(
+ prompt=PROMPT,
+ num_inference_steps=args.num_inference_steps,
+ num_images_per_prompt=args.batch_size,
+ guidance_scale=1.0,
+ )
+
+ def benchmark(self, args):
+ flush()
+
+ print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
+
+ time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
+ memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
+ benchmark_info = BenchmarkInfo(time=time, memory=memory)
+
+ pipeline_class_name = str(self.pipe.__class__.__name__)
+ flush()
+ csv_dict = generate_csv_dict(
+ pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info
+ )
+ filepath = self.get_result_filepath(args)
+ write_to_csv(filepath, csv_dict)
+ print(f"Logs written to: {filepath}")
+ flush()
+
+
+class ImageToImageBenchmark(TextToImageBenchmark):
+ pipeline_class = AutoPipelineForImage2Image
+ url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
+ image = load_image(url).convert("RGB")
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
+
+ def run_inference(self, pipe, args):
+ _ = pipe(
+ prompt=PROMPT,
+ image=self.image,
+ num_inference_steps=args.num_inference_steps,
+ num_images_per_prompt=args.batch_size,
+ )
+
+
+class TurboImageToImageBenchmark(ImageToImageBenchmark):
+ def __init__(self, args):
+ super().__init__(args)
+
+ def run_inference(self, pipe, args):
+ _ = pipe(
+ prompt=PROMPT,
+ image=self.image,
+ num_inference_steps=args.num_inference_steps,
+ num_images_per_prompt=args.batch_size,
+ guidance_scale=0.0,
+ strength=0.5,
+ )
+
+
+class InpaintingBenchmark(ImageToImageBenchmark):
+ pipeline_class = AutoPipelineForInpainting
+ mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
+ mask = load_image(mask_url).convert("RGB")
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
+ self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])
+
+ def run_inference(self, pipe, args):
+ _ = pipe(
+ prompt=PROMPT,
+ image=self.image,
+ mask_image=self.mask,
+ num_inference_steps=args.num_inference_steps,
+ num_images_per_prompt=args.batch_size,
+ )
+
+
+class IPAdapterTextToImageBenchmark(TextToImageBenchmark):
+ url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
+ image = load_image(url)
+
+ def __init__(self, args):
+ pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda")
+ pipe.load_ip_adapter(
+ args.ip_adapter_id[0],
+ subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models",
+ weight_name=args.ip_adapter_id[1],
+ )
+
+ if args.run_compile:
+ pipe.unet.to(memory_format=torch.channels_last)
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+ pipe.set_progress_bar_config(disable=True)
+ self.pipe = pipe
+
+ def run_inference(self, pipe, args):
+ _ = pipe(
+ prompt=PROMPT,
+ ip_adapter_image=self.image,
+ num_inference_steps=args.num_inference_steps,
+ num_images_per_prompt=args.batch_size,
+ )
+
+
+class ControlNetBenchmark(TextToImageBenchmark):
+ pipeline_class = StableDiffusionControlNetPipeline
+ aux_network_class = ControlNetModel
+ root_ckpt = "Lykon/DreamShaper"
+
+ url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
+ image = load_image(url).convert("RGB")
+
+ def __init__(self, args):
+ aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
+ pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16)
+ pipe = pipe.to("cuda")
+
+ pipe.set_progress_bar_config(disable=True)
+ self.pipe = pipe
+
+ if args.run_compile:
+ pipe.unet.to(memory_format=torch.channels_last)
+ pipe.controlnet.to(memory_format=torch.channels_last)
+
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
+
+ self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
+
+ def run_inference(self, pipe, args):
+ _ = pipe(
+ prompt=PROMPT,
+ image=self.image,
+ num_inference_steps=args.num_inference_steps,
+ num_images_per_prompt=args.batch_size,
+ )
+
+
+class ControlNetSDXLBenchmark(ControlNetBenchmark):
+ pipeline_class = StableDiffusionXLControlNetPipeline
+ root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
+
+ def __init__(self, args):
+ super().__init__(args)
+
+
+class T2IAdapterBenchmark(ControlNetBenchmark):
+ pipeline_class = StableDiffusionAdapterPipeline
+ aux_network_class = T2IAdapter
+ root_ckpt = "Lykon/DreamShaper"
+
+ url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
+ image = load_image(url).convert("L")
+
+ def __init__(self, args):
+ aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
+ pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16)
+ pipe = pipe.to("cuda")
+
+ pipe.set_progress_bar_config(disable=True)
+ self.pipe = pipe
+
+ if args.run_compile:
+ pipe.unet.to(memory_format=torch.channels_last)
+ pipe.adapter.to(memory_format=torch.channels_last)
+
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True)
+
+ self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
+
+
+class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark):
+ pipeline_class = StableDiffusionXLAdapterPipeline
+ root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
+
+ url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png"
+ image = load_image(url)
+
+ def __init__(self, args):
+ super().__init__(args)
diff --git a/diffusers/benchmarks/benchmark_controlnet.py b/diffusers/benchmarks/benchmark_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9217004461dc9352b1b9e6cda698dd866177eb67
--- /dev/null
+++ b/diffusers/benchmarks/benchmark_controlnet.py
@@ -0,0 +1,26 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark # noqa: E402
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="lllyasviel/sd-controlnet-canny",
+ choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"],
+ )
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--num_inference_steps", type=int, default=50)
+ parser.add_argument("--model_cpu_offload", action="store_true")
+ parser.add_argument("--run_compile", action="store_true")
+ args = parser.parse_args()
+
+ benchmark_pipe = (
+ ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args)
+ )
+ benchmark_pipe.benchmark(args)
diff --git a/diffusers/benchmarks/benchmark_ip_adapters.py b/diffusers/benchmarks/benchmark_ip_adapters.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a31a21fc60ddd2d041d7896ee4144d1766fcc01
--- /dev/null
+++ b/diffusers/benchmarks/benchmark_ip_adapters.py
@@ -0,0 +1,33 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import IPAdapterTextToImageBenchmark # noqa: E402
+
+
+IP_ADAPTER_CKPTS = {
+ # because original SD v1.5 has been taken down.
+ "Lykon/DreamShaper": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
+ "stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"),
+}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="rstabilityai/stable-diffusion-xl-base-1.0",
+ choices=list(IP_ADAPTER_CKPTS.keys()),
+ )
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--num_inference_steps", type=int, default=50)
+ parser.add_argument("--model_cpu_offload", action="store_true")
+ parser.add_argument("--run_compile", action="store_true")
+ args = parser.parse_args()
+
+ args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt]
+ benchmark_pipe = IPAdapterTextToImageBenchmark(args)
+ args.ckpt = f"{args.ckpt} (IP-Adapter)"
+ benchmark_pipe.benchmark(args)
diff --git a/diffusers/benchmarks/benchmark_sd_img.py b/diffusers/benchmarks/benchmark_sd_img.py
new file mode 100644
index 0000000000000000000000000000000000000000..772befe8795fb933fdb7831565eb28848e642af2
--- /dev/null
+++ b/diffusers/benchmarks/benchmark_sd_img.py
@@ -0,0 +1,29 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import ImageToImageBenchmark, TurboImageToImageBenchmark # noqa: E402
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="Lykon/DreamShaper",
+ choices=[
+ "Lykon/DreamShaper",
+ "stabilityai/stable-diffusion-2-1",
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
+ "stabilityai/sdxl-turbo",
+ ],
+ )
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--num_inference_steps", type=int, default=50)
+ parser.add_argument("--model_cpu_offload", action="store_true")
+ parser.add_argument("--run_compile", action="store_true")
+ args = parser.parse_args()
+
+ benchmark_pipe = ImageToImageBenchmark(args) if "turbo" not in args.ckpt else TurboImageToImageBenchmark(args)
+ benchmark_pipe.benchmark(args)
diff --git a/diffusers/benchmarks/benchmark_sd_inpainting.py b/diffusers/benchmarks/benchmark_sd_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..143adcb0d87c7a9cf9c8ae0d19581fd48f9ac4c8
--- /dev/null
+++ b/diffusers/benchmarks/benchmark_sd_inpainting.py
@@ -0,0 +1,28 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import InpaintingBenchmark # noqa: E402
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="Lykon/DreamShaper",
+ choices=[
+ "Lykon/DreamShaper",
+ "stabilityai/stable-diffusion-2-1",
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ ],
+ )
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--num_inference_steps", type=int, default=50)
+ parser.add_argument("--model_cpu_offload", action="store_true")
+ parser.add_argument("--run_compile", action="store_true")
+ args = parser.parse_args()
+
+ benchmark_pipe = InpaintingBenchmark(args)
+ benchmark_pipe.benchmark(args)
diff --git a/diffusers/benchmarks/benchmark_t2i_adapter.py b/diffusers/benchmarks/benchmark_t2i_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..44b04b470ea65d5f3318bee21bb107c7b4b2b2f9
--- /dev/null
+++ b/diffusers/benchmarks/benchmark_t2i_adapter.py
@@ -0,0 +1,28 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import T2IAdapterBenchmark, T2IAdapterSDXLBenchmark # noqa: E402
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="TencentARC/t2iadapter_canny_sd14v1",
+ choices=["TencentARC/t2iadapter_canny_sd14v1", "TencentARC/t2i-adapter-canny-sdxl-1.0"],
+ )
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--num_inference_steps", type=int, default=50)
+ parser.add_argument("--model_cpu_offload", action="store_true")
+ parser.add_argument("--run_compile", action="store_true")
+ args = parser.parse_args()
+
+ benchmark_pipe = (
+ T2IAdapterBenchmark(args)
+ if args.ckpt == "TencentARC/t2iadapter_canny_sd14v1"
+ else T2IAdapterSDXLBenchmark(args)
+ )
+ benchmark_pipe.benchmark(args)
diff --git a/diffusers/benchmarks/benchmark_t2i_lcm_lora.py b/diffusers/benchmarks/benchmark_t2i_lcm_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..957e0a463e28fccc51fe32cd975f3d5234cfd1f2
--- /dev/null
+++ b/diffusers/benchmarks/benchmark_t2i_lcm_lora.py
@@ -0,0 +1,23 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import LCMLoRATextToImageBenchmark # noqa: E402
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="stabilityai/stable-diffusion-xl-base-1.0",
+ )
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--num_inference_steps", type=int, default=4)
+ parser.add_argument("--model_cpu_offload", action="store_true")
+ parser.add_argument("--run_compile", action="store_true")
+ args = parser.parse_args()
+
+ benchmark_pipe = LCMLoRATextToImageBenchmark(args)
+ benchmark_pipe.benchmark(args)
diff --git a/diffusers/benchmarks/benchmark_text_to_image.py b/diffusers/benchmarks/benchmark_text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddc7fb2676a5c6090ce08e4488f5ca697c4329aa
--- /dev/null
+++ b/diffusers/benchmarks/benchmark_text_to_image.py
@@ -0,0 +1,40 @@
+import argparse
+import sys
+
+
+sys.path.append(".")
+from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark # noqa: E402
+
+
+ALL_T2I_CKPTS = [
+ "Lykon/DreamShaper",
+ "segmind/SSD-1B",
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ "kandinsky-community/kandinsky-2-2-decoder",
+ "warp-ai/wuerstchen",
+ "stabilityai/sdxl-turbo",
+]
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="Lykon/DreamShaper",
+ choices=ALL_T2I_CKPTS,
+ )
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--num_inference_steps", type=int, default=50)
+ parser.add_argument("--model_cpu_offload", action="store_true")
+ parser.add_argument("--run_compile", action="store_true")
+ args = parser.parse_args()
+
+ benchmark_cls = None
+ if "turbo" in args.ckpt:
+ benchmark_cls = TurboTextToImageBenchmark
+ else:
+ benchmark_cls = TextToImageBenchmark
+
+ benchmark_pipe = benchmark_cls(args)
+ benchmark_pipe.benchmark(args)
diff --git a/diffusers/benchmarks/push_results.py b/diffusers/benchmarks/push_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..962e07c6d74cd9c9d0b18657e935d244a131a9de
--- /dev/null
+++ b/diffusers/benchmarks/push_results.py
@@ -0,0 +1,72 @@
+import glob
+import sys
+
+import pandas as pd
+from huggingface_hub import hf_hub_download, upload_file
+from huggingface_hub.utils._errors import EntryNotFoundError
+
+
+sys.path.append(".")
+from utils import BASE_PATH, FINAL_CSV_FILE, GITHUB_SHA, REPO_ID, collate_csv # noqa: E402
+
+
+def has_previous_benchmark() -> str:
+ csv_path = None
+ try:
+ csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILE)
+ except EntryNotFoundError:
+ csv_path = None
+ return csv_path
+
+
+def filter_float(value):
+ if isinstance(value, str):
+ return float(value.split()[0])
+ return value
+
+
+def push_to_hf_dataset():
+ all_csvs = sorted(glob.glob(f"{BASE_PATH}/*.csv"))
+ collate_csv(all_csvs, FINAL_CSV_FILE)
+
+ # If there's an existing benchmark file, we should report the changes.
+ csv_path = has_previous_benchmark()
+ if csv_path is not None:
+ current_results = pd.read_csv(FINAL_CSV_FILE)
+ previous_results = pd.read_csv(csv_path)
+
+ numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns
+ numeric_columns = [
+ c for c in numeric_columns if c not in ["batch_size", "num_inference_steps", "actual_gpu_memory (gbs)"]
+ ]
+
+ for column in numeric_columns:
+ previous_results[column] = previous_results[column].map(lambda x: filter_float(x))
+
+ # Calculate the percentage change
+ current_results[column] = current_results[column].astype(float)
+ previous_results[column] = previous_results[column].astype(float)
+ percent_change = ((current_results[column] - previous_results[column]) / previous_results[column]) * 100
+
+ # Format the values with '+' or '-' sign and append to original values
+ current_results[column] = current_results[column].map(str) + percent_change.map(
+ lambda x: f" ({'+' if x > 0 else ''}{x:.2f}%)"
+ )
+ # There might be newly added rows. So, filter out the NaNs.
+ current_results[column] = current_results[column].map(lambda x: x.replace(" (nan%)", ""))
+
+ # Overwrite the current result file.
+ current_results.to_csv(FINAL_CSV_FILE, index=False)
+
+ commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
+ upload_file(
+ repo_id=REPO_ID,
+ path_in_repo=FINAL_CSV_FILE,
+ path_or_fileobj=FINAL_CSV_FILE,
+ repo_type="dataset",
+ commit_message=commit_message,
+ )
+
+
+if __name__ == "__main__":
+ push_to_hf_dataset()
diff --git a/diffusers/benchmarks/run_all.py b/diffusers/benchmarks/run_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9932cc71c38513a301854a5ad227a2e3ec28d30
--- /dev/null
+++ b/diffusers/benchmarks/run_all.py
@@ -0,0 +1,101 @@
+import glob
+import subprocess
+import sys
+from typing import List
+
+
+sys.path.append(".")
+from benchmark_text_to_image import ALL_T2I_CKPTS # noqa: E402
+
+
+PATTERN = "benchmark_*.py"
+
+
+class SubprocessCallException(Exception):
+ pass
+
+
+# Taken from `test_examples_utils.py`
+def run_command(command: List[str], return_stdout=False):
+ """
+    Runs `command` with `subprocess.check_output` and optionally returns its `stdout`. Also properly surfaces
+    any error that occurs while running `command`.
+ """
+ try:
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ if return_stdout:
+ if hasattr(output, "decode"):
+ output = output.decode("utf-8")
+ return output
+ except subprocess.CalledProcessError as e:
+ raise SubprocessCallException(
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+ ) from e
+
+
+def main():
+ python_files = glob.glob(PATTERN)
+
+ for file in python_files:
+ print(f"****** Running file: {file} ******")
+
+ # Run with canonical settings.
+ if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py":
+ command = f"python {file}"
+ run_command(command.split())
+
+ command += " --run_compile"
+ run_command(command.split())
+
+ # Run variants.
+ for file in python_files:
+ # See: https://github.com/pytorch/pytorch/issues/129637
+ if file == "benchmark_ip_adapters.py":
+ continue
+
+ if file == "benchmark_text_to_image.py":
+ for ckpt in ALL_T2I_CKPTS:
+ command = f"python {file} --ckpt {ckpt}"
+
+ if "turbo" in ckpt:
+ command += " --num_inference_steps 1"
+
+ run_command(command.split())
+
+ command += " --run_compile"
+ run_command(command.split())
+
+ elif file == "benchmark_sd_img.py":
+ for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
+ command = f"python {file} --ckpt {ckpt}"
+
+ if ckpt == "stabilityai/sdxl-turbo":
+ command += " --num_inference_steps 2"
+
+ run_command(command.split())
+ command += " --run_compile"
+ run_command(command.split())
+
+ elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]:
+ sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
+ command = f"python {file} --ckpt {sdxl_ckpt}"
+ run_command(command.split())
+
+ command += " --run_compile"
+ run_command(command.split())
+
+ elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]:
+ sdxl_ckpt = (
+ "diffusers/controlnet-canny-sdxl-1.0"
+ if "controlnet" in file
+ else "TencentARC/t2i-adapter-canny-sdxl-1.0"
+ )
+ command = f"python {file} --ckpt {sdxl_ckpt}"
+ run_command(command.split())
+
+ command += " --run_compile"
+ run_command(command.split())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/benchmarks/utils.py b/diffusers/benchmarks/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fce920ac6c3549e3654b1cfb2f0e79096aa019d
--- /dev/null
+++ b/diffusers/benchmarks/utils.py
@@ -0,0 +1,98 @@
+import argparse
+import csv
+import gc
+import os
+from dataclasses import dataclass
+from typing import Dict, List, Union
+
+import torch
+import torch.utils.benchmark as benchmark
+
+
+GITHUB_SHA = os.getenv("GITHUB_SHA", None)
+BENCHMARK_FIELDS = [
+ "pipeline_cls",
+ "ckpt_id",
+ "batch_size",
+ "num_inference_steps",
+ "model_cpu_offload",
+ "run_compile",
+ "time (secs)",
+ "memory (gbs)",
+ "actual_gpu_memory (gbs)",
+ "github_sha",
+]
+
+PROMPT = "ghibli style, a fantasy landscape with castles"
+BASE_PATH = os.getenv("BASE_PATH", ".")
+TOTAL_GPU_MEMORY = float(os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3)))
+
+REPO_ID = "diffusers/benchmarks"
+FINAL_CSV_FILE = "collated_results.csv"
+
+
+@dataclass
+class BenchmarkInfo:
+ time: float
+ memory: float
+
+
+def flush():
+ """Wipes off memory."""
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+ torch.cuda.reset_peak_memory_stats()
+
+
+def bytes_to_giga_bytes(bytes):
+ return f"{(bytes / 1024 / 1024 / 1024):.3f}"
+
+
+def benchmark_fn(f, *args, **kwargs):
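+    # Time `f(*args, **kwargs)` with torch.utils.benchmark and return the mean wall-clock time (seconds) as a formatted string.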
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)",
+ globals={"args": args, "kwargs": kwargs, "f": f},
+ num_threads=torch.get_num_threads(),
+ )
+ return f"{(t0.blocked_autorange().mean):.3f}"
+
+
+def generate_csv_dict(
+ pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo
+) -> Dict[str, Union[str, bool, float]]:
+ """Packs benchmarking data into a dictionary for latter serialization."""
+ data_dict = {
+ "pipeline_cls": pipeline_cls,
+ "ckpt_id": ckpt,
+ "batch_size": args.batch_size,
+ "num_inference_steps": args.num_inference_steps,
+ "model_cpu_offload": args.model_cpu_offload,
+ "run_compile": args.run_compile,
+ "time (secs)": benchmark_info.time,
+ "memory (gbs)": benchmark_info.memory,
+ "actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}",
+ "github_sha": GITHUB_SHA,
+ }
+ return data_dict
+
+
+def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]):
+ """Serializes a dictionary into a CSV file."""
+ with open(file_name, mode="w", newline="") as csvfile:
+ writer = csv.DictWriter(csvfile, fieldnames=BENCHMARK_FIELDS)
+ writer.writeheader()
+ writer.writerow(data_dict)
+
+
+def collate_csv(input_files: List[str], output_file: str):
+ """Collates multiple identically structured CSVs into a single CSV file."""
+ with open(output_file, mode="w", newline="") as outfile:
+ writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS)
+ writer.writeheader()
+
+ for file in input_files:
+ with open(file, mode="r") as infile:
+ reader = csv.DictReader(infile)
+ for row in reader:
+ writer.writerow(row)
diff --git a/diffusers/docker/diffusers-doc-builder/Dockerfile b/diffusers/docker/diffusers-doc-builder/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..c9fc62707cb0dac3126c06ebc22420c1353146ff
--- /dev/null
+++ b/diffusers/docker/diffusers-doc-builder/Dockerfile
@@ -0,0 +1,52 @@
+FROM ubuntu:20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ python3.10 \
+ python3-pip \
+ libgl1 \
+ zip \
+ wget \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3.10 -m uv pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ invisible_watermark \
+ --extra-index-url https://download.pytorch.org/whl/cpu && \
+ python3.10 -m uv pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+ transformers \
+ matplotlib \
+ setuptools==69.5.1
+
+CMD ["/bin/bash"]
diff --git a/diffusers/docker/diffusers-flax-cpu/Dockerfile b/diffusers/docker/diffusers-flax-cpu/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..051008aa9a2ee057a56bfe55b7af34595c26c413
--- /dev/null
+++ b/diffusers/docker/diffusers-flax-cpu/Dockerfile
@@ -0,0 +1,49 @@
+FROM ubuntu:20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.10 \
+ python3-pip \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
+RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3 -m uv pip install --upgrade --no-cache-dir \
+ clu \
+ "jax[cpu]>=0.2.16,!=0.3.2" \
+ "flax>=0.4.1" \
+ "jaxlib>=0.1.65" && \
+ python3 -m uv pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+ transformers \
+ hf_transfer
+
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/diffusers/docker/diffusers-flax-tpu/Dockerfile b/diffusers/docker/diffusers-flax-tpu/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..405f068923b79136197daa8089340734d1435c60
--- /dev/null
+++ b/diffusers/docker/diffusers-flax-tpu/Dockerfile
@@ -0,0 +1,51 @@
+FROM ubuntu:20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.10 \
+ python3-pip \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
+RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3 -m pip install --no-cache-dir \
+ "jax[tpu]>=0.2.16,!=0.3.2" \
+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
+ python3 -m uv pip install --upgrade --no-cache-dir \
+ clu \
+ "flax>=0.4.1" \
+ "jaxlib>=0.1.65" && \
+ python3 -m uv pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+ transformers \
+ hf_transfer
+
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/diffusers/docker/diffusers-onnxruntime-cpu/Dockerfile b/diffusers/docker/diffusers-onnxruntime-cpu/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..6f4b13e8a9ba0368bcc55336e7574f326b9cef80
--- /dev/null
+++ b/diffusers/docker/diffusers-onnxruntime-cpu/Dockerfile
@@ -0,0 +1,49 @@
+FROM ubuntu:20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.10 \
+ python3-pip \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3 -m uv pip install --no-cache-dir \
+ torch==2.1.2 \
+ torchvision==0.16.2 \
+ torchaudio==2.1.2 \
+ onnxruntime \
+ --extra-index-url https://download.pytorch.org/whl/cpu && \
+ python3 -m uv pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+ transformers \
+ hf_transfer
+
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/diffusers/docker/diffusers-onnxruntime-cuda/Dockerfile b/diffusers/docker/diffusers-onnxruntime-cuda/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..bd1d871033c96bd9eb0aa6157747c0fad3fa9093
--- /dev/null
+++ b/diffusers/docker/diffusers-onnxruntime-cuda/Dockerfile
@@ -0,0 +1,50 @@
+FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.10 \
+ python3-pip \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3.10 -m uv pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ "onnxruntime-gpu>=1.13.1" \
+ --extra-index-url https://download.pytorch.org/whl/cu117 && \
+ python3.10 -m uv pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+ transformers \
+ hf_transfer
+
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/diffusers/docker/diffusers-pytorch-compile-cuda/Dockerfile b/diffusers/docker/diffusers-pytorch-compile-cuda/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..cb4a9c0f9896b0a1e8c3f21ef7baef1035a70d44
--- /dev/null
+++ b/diffusers/docker/diffusers-pytorch-compile-cuda/Dockerfile
@@ -0,0 +1,50 @@
+FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.10 \
+ python3.10-dev \
+ python3-pip \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3.10 -m uv pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ invisible_watermark && \
+ python3.10 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+ transformers \
+ hf_transfer
+
+CMD ["/bin/bash"]
diff --git a/diffusers/docker/diffusers-pytorch-cpu/Dockerfile b/diffusers/docker/diffusers-pytorch-cpu/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..8d98c52598d2095a2842f97a2f68c56391e24431
--- /dev/null
+++ b/diffusers/docker/diffusers-pytorch-cpu/Dockerfile
@@ -0,0 +1,50 @@
+FROM ubuntu:20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ python3.10 \
+ python3.10-dev \
+ python3-pip \
+ libgl1 \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3.10 -m uv pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ invisible_watermark \
+ --extra-index-url https://download.pytorch.org/whl/cpu && \
+ python3.10 -m uv pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+    transformers \
+    matplotlib \
+ hf_transfer
+
+CMD ["/bin/bash"]
diff --git a/diffusers/docker/diffusers-pytorch-cuda/Dockerfile b/diffusers/docker/diffusers-pytorch-cuda/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..695f5ed08dc5b202b84f57f3bfab7fe8559c9d97
--- /dev/null
+++ b/diffusers/docker/diffusers-pytorch-cuda/Dockerfile
@@ -0,0 +1,51 @@
+FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.10 \
+ python3.10-dev \
+ python3-pip \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3.10 -m uv pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ invisible_watermark && \
+ python3.10 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+ transformers \
+ pytorch-lightning \
+ hf_transfer
+
+CMD ["/bin/bash"]
diff --git a/diffusers/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/diffusers/docker/diffusers-pytorch-xformers-cuda/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..1693eb2930242d950f591aa461e7d527c7268fc6
--- /dev/null
+++ b/diffusers/docker/diffusers-pytorch-xformers-cuda/Dockerfile
@@ -0,0 +1,51 @@
+FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.10 \
+ python3.10-dev \
+ python3-pip \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3.10 -m pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ invisible_watermark && \
+ python3.10 -m uv pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+ transformers \
+ xformers \
+ hf_transfer
+
+CMD ["/bin/bash"]
diff --git a/diffusers/examples/README.md b/diffusers/examples/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c275070405456548a48fa63c73e473a9173f327e
--- /dev/null
+++ b/diffusers/examples/README.md
@@ -0,0 +1,70 @@
+
+
+# 🧨 Diffusers Examples
+
+Diffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library
+for a variety of use cases involving training or fine-tuning.
+
+**Note**: If you are looking for **official** examples on how to use `diffusers` for inference, please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
+
+Our examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
+More specifically, this means:
+
+- **Self-contained**: An example script shall only depend on "pip-install-able" Python packages that can be found in a `requirements.txt` file. Example scripts shall **not** depend on any local files. This means that one can simply download an example script, *e.g.* [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), install the required dependencies, *e.g.* [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt) and execute the example script.
+- **Easy-to-tweak**: While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out of the box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data and the training loop to allow you to tweak and edit them as required.
+- **Beginner-friendly**: We do not aim to provide state-of-the-art training scripts for the newest models, but rather examples that can be used as a way to better understand diffusion models and how to use them with the `diffusers` library. We often purposefully leave out certain state-of-the-art methods if we consider them too complex for beginners.
+- **One-purpose-only**: Examples should show one task and one task only. Even if tasks are very similar from a modeling point of view, *e.g.* image super-resolution and image modification tend to use the same model and training method, we want examples to showcase only one task to keep them as readable and easy to understand as possible.
+
+We provide **official** examples that cover the most popular tasks of diffusion models.
+*Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above.
+If you feel like another important example should exist, we are more than happy to welcome a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) or directly a [Pull Request](https://github.com/huggingface/diffusers/compare) from you!
+
+Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support:
+
+| Task | 🤗 Accelerate | 🤗 Datasets | Colab
+|---|---|:---:|:---:|
+| [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
+| [**Textual Inversion**](./textual_inversion) | ✅ | - | [Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+| [**Dreambooth**](./dreambooth) | ✅ | - | [Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
+| [**ControlNet**](./controlnet) | ✅ | ✅ | -
+| [**InstructPix2Pix**](./instruct_pix2pix) | ✅ | ✅ | -
+| [**Reinforcement Learning for Control**](./reinforcement_learning) | - | - | coming soon.
+
+## Community
+
+In addition, we provide **community** examples, which are examples added and maintained by our community.
+Community examples can consist of both *training* examples or *inference* pipelines.
+For such examples, we are more lenient regarding the philosophy defined above and also cannot guarantee to provide maintenance for every issue.
+Examples that are useful for the community, but are either not yet deemed popular or not yet following our above philosophy should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines.
+**Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show to the community how you like to use `diffusers` 🪄.
+
+## Research Projects
+
+We also provide **research_projects** examples that are maintained by the community as defined in the respective research project folders. These examples are useful and offer extended capabilities that complement the official examples. You may refer to [research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) for details.
+
+## Important note
+
+To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+Then cd into the example folder of your choice and run
+```bash
+pip install -r requirements.txt
+```
diff --git a/diffusers/examples/advanced_diffusion_training/README.md b/diffusers/examples/advanced_diffusion_training/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cd8c5feda9f038802590fe179e90873d64579ad1
--- /dev/null
+++ b/diffusers/examples/advanced_diffusion_training/README.md
@@ -0,0 +1,454 @@
+# Advanced diffusion training examples
+
+## Train Dreambooth LoRA with Stable Diffusion XL
+> [!TIP]
+> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3-5) images of a subject.
+
+LoRA - Low-Rank Adaptation of Large Language Models - was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow you to control the extent to which the model is adapted to new training images via a `scale` parameter.
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
+the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement DreamBooth-LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py` with
+advanced features and techniques, inspired by and built upon contributions from [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
+[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
+
+> [!NOTE]
+> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
+> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
+
+📚 Read more about the advanced features and best practices in this community-derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)
+
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/advanced_diffusion_training` folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
+
+### Pivotal Tuning
+**Training with text encoder(s)**
+
+Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the regular text encoder optimization
+available in `train_dreambooth_lora_sdxl_advanced.py`, the advanced script also supports **pivotal tuning**.
+[Pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
+we insert new tokens into the text encoders of the model instead of reusing existing ones,
+and then optimize the newly inserted token embeddings to represent the new concept (see the short sketch after the notes below).
+
+To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
+Please keep the following points in mind:
+
+* SDXL has two text encoders. So, we fine-tune both using LoRA.
+* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
+
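+Below is a minimal sketch of the token-insertion idea behind pivotal tuning; the model id, token names and variable names here are illustrative and not the training script's exact code:
+
+```python
+# Minimal sketch of pivotal tuning's token insertion (illustrative; not the training script itself).
+import torch
+from transformers import CLIPTokenizer, CLIPTextModel
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
+
+# insert new placeholder tokens for the concept and grow the embedding table accordingly
+new_tokens = ["<s0>", "<s1>"]
+tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
+text_encoder.resize_token_embeddings(len(tokenizer))
+
+# only the newly added embedding rows would be optimized; all other weights stay frozen
+new_ids = tokenizer.convert_tokens_to_ids(new_tokens)
+embedding_weight = text_encoder.get_input_embeddings().weight
+trainable_rows = torch.zeros(embedding_weight.shape[0], dtype=torch.bool)
+trainable_rows[new_ids] = True
+```
+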
+### 3D icon example
+
+Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./3d_icon"
+snapshot_download(
+ "LinoyTsaban/3d_icon",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+Let's review some of the advanced features we're going to be using for this example:
+- **custom captions**:
+To use custom captioning, first ensure that you have the `datasets` library installed; if not, you can install it with
+```bash
+pip install datasets
+```
+
+Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
+
+```
+--dataset_name=./3d_icon
+--caption_column=prompt
+```
+
+You can also load a dataset straight from the Hub by specifying its name in `dataset_name`.
+Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset (a small loading sketch follows this list).
+
+- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
+- **pivotal tuning**
+- **min SNR gamma**
+
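+If you want to sanity-check a locally downloaded caption dataset before training, here is a minimal sketch. It assumes the folder follows the 🤗 Datasets ImageFolder layout with a metadata file exposing a `prompt` column, as in this example:
+
+```python
+# Minimal sketch: inspect a local ImageFolder-style dataset with custom captions.
+# Assumes ./3d_icon contains the images plus a metadata.jsonl providing a "prompt" column.
+from datasets import load_dataset
+
+dataset = load_dataset("imagefolder", data_dir="./3d_icon", split="train")
+print(dataset.column_names)   # e.g. ['image', 'prompt']
+print(dataset[0]["prompt"])   # the custom caption used for the first image
+```
+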
+**Now, we can launch training:**
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export DATASET_NAME="./3d_icon"
+export OUTPUT_DIR="3d-icon-SDXL-LoRA"
+export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
+
+accelerate launch train_dreambooth_lora_sdxl_advanced.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=$VAE_PATH \
+ --dataset_name=$DATASET_NAME \
+ --instance_prompt="3d icon in the style of TOK" \
+ --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
+ --output_dir=$OUTPUT_DIR \
+ --caption_column="prompt" \
+ --mixed_precision="bf16" \
+ --resolution=1024 \
+ --train_batch_size=3 \
+ --repeats=1 \
+ --report_to="wandb"\
+ --gradient_accumulation_steps=1 \
+ --gradient_checkpointing \
+ --learning_rate=1.0 \
+ --text_encoder_lr=1.0 \
+ --optimizer="prodigy"\
+ --train_text_encoder_ti\
+ --train_text_encoder_ti_frac=0.5\
+ --snr_gamma=5.0 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --rank=8 \
+ --max_train_steps=1000 \
+ --checkpointing_steps=2000 \
+ --seed="0" \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs, so we can qualitatively check whether the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+
+### Inference
+
+Once training is done, we can perform inference like so:
+1. Start by loading the UNet LoRA weights:
+```python
+import torch
+from huggingface_hub import hf_hub_download, upload_file
+from diffusers import DiffusionPipeline
+from diffusers.models import AutoencoderKL
+from safetensors.torch import load_file
+
+username = "linoyts"
+repo_id = f"{username}/3d-icon-SDXL-LoRA"
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+).to("cuda")
+
+
+pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
+```
+2. Now load the pivotal tuning embeddings:
+
+```python
+text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
+
+embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")
+
+state_dict = load_file(embedding_path)
+# load embeddings of text_encoder 1 (CLIP ViT-L/14)
+pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+# load embeddings of text_encoder 2 (CLIP ViT-G/14)
+pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+```
+
+3. Let's generate images:
+
+```python
+instance_token = "<s0><s1>"
+prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
+
+image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
+image.save("llama.png")
+```
+
+### Comfy UI / AUTOMATIC1111 Inference
+The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
+
+**AUTOMATIC1111 / SD.Next** \
+In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
+- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
+- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
+
+You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
+
+**ComfyUI** \
+In ComfyUI we will load a LoRA and a textual embedding at the same time.
+- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
+- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, namely `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
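+
+For instance, the same VAE can be used at inference time as well; a minimal sketch using the model ids referenced above:
+
+```python
+# Sketch: swapping in the numerically more stable fp16-fix VAE for SDXL inference.
+import torch
+from diffusers import AutoencoderKL, StableDiffusionXLPipeline
+
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
+).to("cuda")
+```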
+
+### DoRA training
+The advanced script supports DoRA training too!
+> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
+**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction**, and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
+The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.
+
+> [!NOTE]
+> 💡DoRA training is still _experimental_
+> and is likely to require different hyperparameter values to perform best compared to a LoRA.
+> Specifically, we've noticed 2 differences to take into account in your training:
+> 1. **LoRA seems to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may work well for a DoRA)
+> 2. **DoRA quality is superior to LoRA, especially at lower ranks**: the difference in quality between a DoRA of rank 8 and a LoRA of rank 8 appears to be more significant than when training at ranks of 32 or 64, for example.
+> This is also aligned with some of the quantitative analysis shown in the paper.
+
+**Usage**
+1. To use DoRA you need to install `peft` from main:
+```bash
+pip install git+https://github.com/huggingface/peft.git
+```
+2. Enable DoRA training by adding this flag
+```bash
+--use_dora
+```
+**Inference**
+The inference is the same as if you train a regular LoRA 🤗
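+
+A minimal sketch of what that looks like (the repo name below is a placeholder, and loading DoRA weights requires a sufficiently recent `peft`):
+
+```python
+# Sketch: DoRA weights load just like a regular LoRA (placeholder repo name).
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+pipe.load_lora_weights("your-username/your-dora-output", weight_name="pytorch_lora_weights.safetensors")
+image = pipe("a TOK icon of a corgi", num_inference_steps=25).images[0]
+```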
+
+## Conducting EDM-style training
+
+It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
+
+Simply set:
+
+```diff
++ --do_edm_style_training \
+```
+
+Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
+
+```bash
+accelerate launch train_dreambooth_lora_sdxl_advanced.py \
+ --pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
+ --dataset_name="linoyts/3d_icon" \
+ --instance_prompt="3d icon in the style of TOK" \
+ --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
+ --output_dir="3d-icon-SDXL-LoRA" \
+ --do_edm_style_training \
+ --caption_column="prompt" \
+ --mixed_precision="bf16" \
+ --resolution=1024 \
+ --train_batch_size=3 \
+ --repeats=1 \
+ --report_to="wandb"\
+ --gradient_accumulation_steps=1 \
+ --gradient_checkpointing \
+ --learning_rate=1.0 \
+ --text_encoder_lr=1.0 \
+ --optimizer="prodigy"\
+ --train_text_encoder_ti\
+ --train_text_encoder_ti_frac=0.5\
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --rank=8 \
+ --max_train_steps=1000 \
+ --checkpointing_steps=2000 \
+ --seed="0" \
+ --push_to_hub
+```
+
+> [!CAUTION]
+> Min-SNR gamma is not supported with EDM-style training yet. When training with the PlaygroundAI model, it's recommended not to pass any `variant`.
+
+### B-LoRA training
+The advanced script now supports B-LoRA training too!
+> Proposed in [Implicit Style-Content Separation using B-LoRA](https://arxiv.org/abs/2403.14572),
+B-LoRA is a method that leverages LoRA to implicitly separate the style and content components of a **single** image.
+It was shown that learning the LoRA weights of two specific blocks (referred to as B-LoRAs)
+achieves style-content separation that cannot be achieved by training each B-LoRA independently.
+Once trained, the two B-LoRAs can be used as independent components to enable various image stylization tasks.
+
+**Usage**
+Enable B-LoRA training by adding this flag
+```bash
+--use_blora
+```
+You can train a B-LoRA with as little as 1 image and 1000 steps. Try this default configuration as a starting point:
+```bash
+accelerate launch train_dreambooth_b-lora_sdxl.py \
+ --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
+ --instance_data_dir="linoyts/B-LoRA_teddy_bear" \
+ --output_dir="B-LoRA_teddy_bear" \
+ --instance_prompt="a [v18]" \
+ --resolution=1024 \
+ --rank=64 \
+ --train_batch_size=1 \
+ --learning_rate=5e-5 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=1000 \
+ --checkpointing_steps=2000 \
+ --seed="0" \
+ --gradient_checkpointing \
+ --mixed_precision="fp16"
+```
+**Inference**
+The inference is a bit different:
+1. we need to load *specific* UNet layers (as opposed to a regular LoRA/DoRA)
+2. the trained layers we load change based on our objective (e.g. style/content)
+
+```python
+import torch
+from diffusers import StableDiffusionXLPipeline, AutoencoderKL
+
+# taken & modified from B-LoRA repo - https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py
+def is_belong_to_blocks(key, blocks):
+ try:
+ for g in blocks:
+ if g in key:
+ return True
+ return False
+ except Exception as e:
+ raise type(e)(f'failed to is_belong_to_block, due to: {e}')
+
+def lora_lora_unet_blocks(lora_path, alpha, target_blocks):
+ state_dict, _ = pipeline.lora_state_dict(lora_path)
+ filtered_state_dict = {k: v * alpha for k, v in state_dict.items() if is_belong_to_blocks(k, target_blocks)}
+ return filtered_state_dict
+
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ vae=vae,
+ torch_dtype=torch.float16,
+).to("cuda")
+
+# pick a blora for content/style (you can also set one to None)
+content_B_lora_path = "lora-library/B-LoRA-teddybear"
+style_B_lora_path= "lora-library/B-LoRA-pen_sketch"
+
+
+content_B_LoRA = lora_lora_unet_blocks(content_B_lora_path,alpha=1,target_blocks=["unet.up_blocks.0.attentions.0"])
+style_B_LoRA = lora_lora_unet_blocks(style_B_lora_path,alpha=1.1,target_blocks=["unet.up_blocks.0.attentions.1"])
+combined_lora = {**content_B_LoRA, **style_B_LoRA}
+
+# Load both loras
+pipeline.load_lora_into_unet(combined_lora, None, pipeline.unet)
+
+#generate
+prompt = "a [v18] in [v30] style"
+pipeline(prompt, num_images_per_prompt=4).images
+```
+### LoRA training of Targeted U-net Blocks
+The advanced script now supports custom choice of U-net blocks to train during Dreambooth LoRA tuning.
+> [!NOTE]
+> This feature is still experimental
+
+> Recently, works like B-LoRA showed the potential advantages of learning the LoRA weights of specific U-net blocks, not only in speed & memory,
+> but also in reducing the amount of needed data, improving style manipulation and overcoming overfitting issues.
+> In light of this, we're introducing a new feature to the advanced script to allow for configurable U-net learned blocks.
+
+**Usage**
+Configure the LoRA-trained U-net blocks by adding the `--lora_unet_blocks` flag, with a comma-separated string specifying the targeted blocks.
+e.g.:
+```bash
+--lora_unet_blocks="unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1"
+```
+
+> [!NOTE]
+> if you specify both `--use_blora` and `--lora_unet_blocks`, the values given in `--lora_unet_blocks` will be ignored.
+> When enabling `--use_blora`, the targeted U-net blocks are automatically set to "unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1", as discussed in the paper.
+> If you wish to experiment with different blocks, specify `--lora_unet_blocks` only.
+
+**Inference**
+Inference is the same as for B-LoRAs, except the input targeted blocks should be modified based on your training configuration.
+```python
+import torch
+from diffusers import StableDiffusionXLPipeline, AutoencoderKL
+
+# taken & modified from B-LoRA repo - https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py
+def is_belong_to_blocks(key, blocks):
+ try:
+ for g in blocks:
+ if g in key:
+ return True
+ return False
+ except Exception as e:
+ raise type(e)(f'failed to is_belong_to_block, due to: {e}')
+
+def lora_lora_unet_blocks(lora_path, alpha, target_blocks):
+ state_dict, _ = pipeline.lora_state_dict(lora_path)
+ filtered_state_dict = {k: v * alpha for k, v in state_dict.items() if is_belong_to_blocks(k, target_blocks)}
+ return filtered_state_dict
+
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ vae=vae,
+ torch_dtype=torch.float16,
+).to("cuda")
+
+lora_path = "lora-library/B-LoRA-pen_sketch"
+
+state_dict = lora_lora_unet_blocks(lora_path, alpha=1, target_blocks=["unet.up_blocks.0.attentions.0"])
+
+# Load trained lora layers into the unet
+pipeline.load_lora_into_unet(state_dict, None, pipeline.unet)
+
+#generate
+prompt = "a dog in [v30] style"
+pipeline(prompt, num_images_per_prompt=4).images
+```
+
+
+### Tips and Tricks
+Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
+
+## Running on Colab Notebook
+Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb)
+to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free Colab, using some of the advanced features (excluding pivotal tuning).
+
diff --git a/diffusers/examples/advanced_diffusion_training/requirements.txt b/diffusers/examples/advanced_diffusion_training/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3f86855e1d1e03c20874e9de4a6037a119509444
--- /dev/null
+++ b/diffusers/examples/advanced_diffusion_training/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+peft==0.7.0
\ No newline at end of file
diff --git a/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e1a0298ba1d88716692910d0ec9c6740601e9e1
--- /dev/null
+++ b/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -0,0 +1,2013 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import hashlib
+import itertools
+import logging
+import math
+import os
+import re
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+# imports of the TokenEmbeddingsHandler class
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import load_file, save_file
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import (
+ check_min_version,
+ convert_all_state_dict_to_peft,
+ convert_state_dict_to_diffusers,
+ convert_state_dict_to_kohya,
+ is_wandb_available,
+)
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ use_dora: bool,
+ images=None,
+ base_model=str,
+ train_text_encoder=False,
+ train_text_encoder_ti=False,
+ token_abstraction_dict=None,
+ instance_prompt=str,
+ validation_prompt=str,
+ repo_folder=None,
+ vae_path=None,
+):
+ img_str = "widget:\n"
+ lora = "lora" if not use_dora else "dora"
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"""
+ - text: '{validation_prompt if validation_prompt else ' ' }'
+ output:
+ url:
+ "image_{i}.png"
+ """
+ if not images:
+ img_str += f"""
+ - text: '{instance_prompt}'
+ """
+ embeddings_filename = f"{repo_folder}_emb"
+    instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
+    ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
+ if instance_prompt_webui != embeddings_filename:
+ instance_prompt_sentence = f"For example, `{instance_prompt_webui}`"
+ else:
+ instance_prompt_sentence = ""
+ trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+ diffusers_imports_pivotal = ""
+ diffusers_example_pivotal = ""
+ webui_example_pivotal = ""
+ if train_text_encoder_ti:
+ trigger_str = (
+            "To trigger image generation of trained concept (or concepts) replace each concept identifier "
+            "in your prompt with the new inserted tokens:\n"
+ )
+ diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+ """
+ diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
+state_dict = load_file(embedding_path)
+pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+ """
+ webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
+    - Place it in your `embeddings` folder
+ - Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence}
+ (you need both the LoRA and the embeddings as they were trained together for this LoRA)
+ """
+ if token_abstraction_dict:
+ for key, value in token_abstraction_dict.items():
+ tokens = "".join(value)
+ trigger_str += f"""
+to trigger concept `{key}` → use `{tokens}` in your prompt \n
+"""
+
+ yaml = f"""---
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- diffusers-training
+- text-to-image
+- diffusers
+- {lora}
+- template:sd-lora
+{img_str}
+base_model: {base_model}
+instance_prompt: {instance_prompt}
+license: openrail++
+---
+"""
+
+ model_card = f"""
+# SD1.5 LoRA DreamBooth - {repo_id}
+
+
+
+## Model description
+
+### These are {repo_id} LoRA adaptation weights for {base_model}.
+
+## Download model
+
+### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
+
+- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.
+    - Place it in your `models/Lora` folder.
+    - On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
+{webui_example_pivotal}
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+{diffusers_imports_pivotal}
+pipeline = AutoPipelineForText2Image.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+{diffusers_example_pivotal}
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## Trigger words
+
+{trigger_str}
+
+## Details
+All [Files & versions](/{repo_id}/tree/main).
+
+The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py).
+
+LoRA for the text encoder was enabled. {train_text_encoder}.
+
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
+Special VAE used for training: {vae_path}.
+
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that 🤗 Datasets can understand. To load the custom captions, the training set directory needs to follow the structure of a "
+ "datasets ImageFolder, containing both the images and the corresponding caption for each image. see: "
+ "https://huggingface.co/docs/datasets/image_dataset for more information"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset. In some cases, a dataset may have more than one configuration (for example "
+ "if it contains different subsets of data within, and you only wish to load a specific subset - in that case specify the desired configuration using --dataset_config_name. Leave as "
+ "None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help="A path to local folder containing the training data of instance images. Specify this arg instead of "
+ "--dataset_name if you wish to train using a local folder without custom captions. If you wish to train with custom captions please specify "
+ "--dataset_name instead.",
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--token_abstraction",
+ type=str,
+ default="TOK",
+ help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
+ "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. "
+ "'TOK,TOK2,TOK3' etc.",
+ )
+
+ parser.add_argument(
+ "--num_new_tokens_per_abstraction",
+ type=int,
+ default=2,
+ help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
+ "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
+        "tokens - <si><si+1> ",
+ )
+
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_ti",
+ action="store_true",
+ help=("Whether to use textual inversion"),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_ti_frac",
+ type=float,
+ default=0.5,
+ help=("The percentage of epochs to perform textual inversion"),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_frac",
+ type=float,
+ default=1.0,
+ help=("The percentage of epochs to perform text encoder tuning"),
+ )
+
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="adamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--use_dora",
+ action="store_true",
+ default=False,
+ help=(
+            "Whether to train a DoRA as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation, https://arxiv.org/abs/2402.09353. "
+ "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ if args.train_text_encoder and args.train_text_encoder_ti:
+ raise ValueError(
+            "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti`. "
+ "For full LoRA text encoder training check --train_text_encoder, for textual "
+ "inversion training check `--train_text_encoder_ti`"
+ )
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+# Taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
+class TokenEmbeddingsHandler:
+ def __init__(self, text_encoders, tokenizers):
+ self.text_encoders = text_encoders
+ self.tokenizers = tokenizers
+
+ self.train_ids: Optional[torch.Tensor] = None
+ self.inserting_toks: Optional[List[str]] = None
+ self.embeddings_settings = {}
+
+ def initialize_new_tokens(self, inserting_toks: List[str]):
+ idx = 0
+ for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+ assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
+ assert all(
+ isinstance(tok, str) for tok in inserting_toks
+ ), "All elements in inserting_toks should be strings."
+
+ self.inserting_toks = inserting_toks
+ special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+ tokenizer.add_special_tokens(special_tokens_dict)
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+
+ # random initialization of new tokens
+ std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+
+ print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
+
+ text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
+ torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+ .to(device=self.device)
+ .to(dtype=self.dtype)
+ * std_token_embedding
+ )
+ self.embeddings_settings[
+ f"original_embeddings_{idx}"
+ ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
+
+ inu = torch.ones((len(tokenizer),), dtype=torch.bool)
+ inu[self.train_ids] = False
+
+ self.embeddings_settings[f"index_no_updates_{idx}"] = inu
+
+ print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
+
+ idx += 1
+
+ # Copied from train_dreambooth_lora_sdxl_advanced.py
+ def save_embeddings(self, file_path: str):
+ assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
+ tensors = {}
+ # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
+ idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
+ for idx, text_encoder in enumerate(self.text_encoders):
+ assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
+ self.tokenizers[0]
+ ), "Tokenizers should be the same."
+ new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+
+ # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
+ # text_encoder 1) to keep compatible with the ecosystem.
+ # Note: When loading with diffusers, any name can work - simply specify in inference
+ tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
+ # tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+ save_file(tensors, file_path)
+
+ @property
+ def dtype(self):
+ return self.text_encoders[0].dtype
+
+ @property
+ def device(self):
+ return self.text_encoders[0].device
+
+ @torch.no_grad()
+ def retract_embeddings(self):
+ for idx, text_encoder in enumerate(self.text_encoders):
+ index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
+ text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+ self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
+ .to(device=text_encoder.device)
+ .to(dtype=text_encoder.dtype)
+ )
+
+ # for the parts that were updated, we need to normalize them
+ # to have the same std as before
+ std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
+
+ index_updates = ~index_no_updates
+ new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+ off_ratio = std_token_embedding / new_embeddings.std()
+
+ new_embeddings = new_embeddings * (off_ratio**0.1)
+ text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ dataset_name,
+ dataset_config_name,
+ cache_dir,
+ image_column,
+ caption_column,
+ train_text_encoder_ti,
+ class_data_root=None,
+ class_num=None,
+ token_abstraction_dict=None, # token mapping for textual inversion
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+ self.token_abstraction_dict = token_abstraction_dict
+ self.train_text_encoder_ti = train_text_encoder_ti
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ dataset_name,
+ dataset_config_name,
+ cache_dir=cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+                raise ValueError("Instance images root doesn't exist.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.instance_images[index % self.num_instance_images]
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ if self.train_text_encoder_ti:
+                    # replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
+ for token_abs, token_replacement in self.token_abstraction_dict.items():
+ caption = caption.replace(token_abs, "".join(token_replacement))
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # custom prompts were provided, but length does not match size of image dataset
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ add_special_tokens=add_special_tokens,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ return prompt_embeds[0]
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ model_id = args.hub_model_id or Path(args.output_dir).name
+ repo_id = None
+ if args.push_to_hub:
+ repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ variant=args.variant,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ vae_scaling_factor = vae.config.scaling_factor
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ if args.train_text_encoder_ti:
+ # we parse the provided token identifier (or identifiers) into a list, i.e. "TOK" -> ["TOK"], "TOK,
+ # TOK2" -> ["TOK", "TOK2"] etc.
+ token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
+ logger.info(f"list of token identifiers: {token_abstraction_list}")
+
+ token_abstraction_dict = {}
+ token_idx = 0
+ for i, token in enumerate(token_abstraction_list):
+ token_abstraction_dict[token] = [
+ f"" for j in range(args.num_new_tokens_per_abstraction)
+ ]
+ token_idx += args.num_new_tokens_per_abstraction - 1
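+ # e.g. --token_abstraction="TOK" with --num_new_tokens_per_abstraction=2 yields {"TOK": ["<s0>", "<s1>"]}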
+
+ # replace instances of --token_abstraction in --instance_prompt with the new tokens: "<s0><s1>" etc.
+ for token_abs, token_replacement in token_abstraction_dict.items():
+ args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
+ if args.with_prior_preservation:
+ args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
+
+ # initialize the new tokens for textual inversion
+ embedding_handler = TokenEmbeddingsHandler([text_encoder_one], [tokenizer_one])
+ inserting_toks = []
+ for new_tok in token_abstraction_dict.values():
+ inserting_toks.extend(new_tok)
+ embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks)
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
+ "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ use_dora=args.use_dora,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ unet.add_adapter(unet_lora_config)
+
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ use_dora=args.use_dora,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+
+ # if we use textual inversion, we freeze all parameters except for the token embeddings
+ # in text encoder
+ elif args.train_text_encoder_ti:
+ text_lora_parameters_one = []
+ for name, param in text_encoder_one.named_parameters():
+ if "token_embedding" in name:
+ # ensure that dtype is float32, even if the rest of the model (which isn't trained) is loaded in fp16
+ param = param.to(dtype=torch.float32)
+ param.requires_grad = True
+ text_lora_parameters_one.append(param)
+ else:
+ param.requires_grad = False
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one])
+ for model in models:
+ for param in model.parameters():
+ # only upcast trainable parameters (LoRA) into fp32
+ if param.requires_grad:
+ param.data = param.to(torch.float32)
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # there are only two options here: either just the unet attention processor layers,
+ # or both the unet and the text encoder attention layers
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+ if args.train_text_encoder:
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ )
+ if args.train_text_encoder_ti:
+ embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+ StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+
+ text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
+ StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder(
+ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
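+ # e.g. a base lr of 1e-4 with gradient_accumulation_steps=2, train_batch_size=4 and 2 processes scales to 1.6e-3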
+
+ unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+ if args.train_text_encoder:
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+
+ # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
+ freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
+
+ # Optimization parameters
+ unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
+ if not freeze_text_encoder:
+ # different learning rate for text encoder and unet
+ text_lora_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder
+ if args.adam_weight_decay_text_encoder
+ else args.adam_weight_decay,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ unet_lora_parameters_with_lr,
+ text_lora_parameters_one_with_lr,
+ ]
+ else:
+ params_to_optimize = [unet_lora_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # override the text encoder parameter group's learning rate with --learning_rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ dataset_name=args.dataset_name,
+ dataset_config_name=args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ image_column=args.image_column,
+ train_text_encoder_ti=args.train_text_encoder_ti,
+ caption_column=args.caption_column,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one]
+ text_encoders = [text_encoder_one]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ return prompt_embeds
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if freeze_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers)
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if freeze_text_encoder:
+ class_prompt_hidden_states = compute_text_embeddings(args.class_prompt, text_encoders, tokenizers)
+
+ # Clear the memory here
+ if freeze_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
+ add_special_tokens = args.train_text_encoder_ti
+
+ if not train_dataset.custom_instance_prompts:
+ if freeze_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+
+ # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
+ # batch prompts on all training steps
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, add_special_tokens)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+
+ if args.train_text_encoder_ti and args.validation_prompt:
+ # replace instances of --token_abstraction in validation prompt with the new tokens: "<s0><s1>" etc.
+ for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
+ args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
+ print("validation prompt:", args.validation_prompt)
+
+ if args.cache_latents:
+ latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=torch.float32
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+ if args.validation_prompt is None:
+ del vae
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
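+ # Accelerate steps the LR scheduler once per process per optimizer step, so the warmup and total
+ # step counts passed to the scheduler are scaled by num_processes below.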
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if not freeze_text_encoder:
+ unet, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth-lora-sd-15", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ if args.train_text_encoder:
+ num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
+ elif args.train_text_encoder_ti:
+ num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ # if performing any kind of optimization of text_encoder params
+ if args.train_text_encoder or args.train_text_encoder_ti:
+ if epoch == num_train_epochs_text_encoder:
+ print("PIVOT HALFWAY", epoch)
+ # stop optimizing the text_encoder params by zeroing their learning rate,
+ # so that from here on only the unet params are updated
+ optimizer.param_groups[1]["lr"] = 0.0
+
+ else:
+ # still optimizing the text encoder
+ text_encoder_one.train()
+ # set the embeddings' requires_grad = True so that gradient checkpointing works
+ if args.train_text_encoder:
+ text_encoder_one.text_model.embeddings.requires_grad_(True)
+
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ prompts = batch["prompts"]
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if freeze_text_encoder:
+ prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)
+
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens)
+
+ if args.cache_latents:
+ model_input = latents_cache[step].sample()
+ else:
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+
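+ # scale the latents by the VAE's scaling factor so their magnitude matches what the UNet and noise scheduler expect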
+ model_input = model_input * vae_scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+ bsz = model_input.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
+ if not train_dataset.custom_instance_prompts:
+ elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
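+ # (with prior preservation the collated batch is half instance and half class images, and prompt_embeds
+ # already stacks both prompts, so it only needs to be repeated bsz // 2 times)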
+
+ else:
+ elems_to_repeat_text_embeds = 1
+
+ # Predict the noise residual
+ if freeze_text_encoder:
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(noisy_model_input, timesteps, prompt_embeds_input).sample
+ else:
+ prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[tokens_one],
+ )
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(noisy_model_input, timesteps, prompt_embeds_input).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
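+ # i.e. weight_t = min(SNR(t), snr_gamma) / SNR(t); for v-prediction, 1 is added so the weight is floored at one.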
+
+ if args.with_prior_preservation:
+ # if we're using prior preservation, we calc snr for instance loss only -
+ # and hence only need timesteps corresponding to instance images
+ snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0)
+ else:
+ snr_timesteps = timesteps
+
+ snr = compute_snr(noise_scheduler, snr_timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one)
+ if (args.train_text_encoder or args.train_text_encoder_ti)
+ else unet_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # every step, we reset the embeddings of tokens we are not training back to their original values,
+ # so that only the newly inserted tokens are effectively optimized
+ if args.train_text_encoder_ti:
+ embedding_handler.retract_embeddings()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ if freeze_text_encoder:
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ tokenizer=tokenizer_one,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, **scheduler_args
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ images = [
+ pipeline(**pipeline_args, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+ )
+ else:
+ text_encoder_lora_layers = None
+
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ )
+
+ if args.train_text_encoder_ti:
+ embeddings_path = f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors"
+ embedding_handler.save_embeddings(embeddings_path)
+
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ # Final inference
+ # Load previous pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # load new tokens
+ if args.train_text_encoder_ti:
+ state_dict = load_file(embeddings_path)
+ all_new_tokens = []
+ for key, value in token_abstraction_dict.items():
+ all_new_tokens.extend(value)
+ pipeline.load_textual_inversion(
+ state_dict["clip_l"],
+ token=all_new_tokens,
+ text_encoder=pipeline.text_encoder,
+ tokenizer=pipeline.tokenizer,
+ )
+ # run inference
+ pipeline = pipeline.to(accelerator.device)
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ # Convert to WebUI format
+ lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
+ peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+ save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
+
+ save_model_card(
+ model_id if not args.push_to_hub else repo_id,
+ use_dora=args.use_dora,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ train_text_encoder_ti=args.train_text_encoder_ti,
+ token_abstraction_dict=train_dataset.token_abstraction_dict,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
new file mode 100644
index 0000000000000000000000000000000000000000..5222c8afe6f1acdae1de0a61d77bd48a0a34cbab
--- /dev/null
+++ b/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -0,0 +1,2459 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import re
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import load_file, save_file
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ EDMEulerScheduler,
+ EulerDiscreteScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+ check_min_version,
+ convert_all_state_dict_to_peft,
+ convert_state_dict_to_diffusers,
+ convert_state_dict_to_kohya,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def determine_scheduler_type(pretrained_model_name_or_path, revision):
+ model_index_filename = "model_index.json"
+ if os.path.isdir(pretrained_model_name_or_path):
+ model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
+ else:
+ model_index = hf_hub_download(
+ repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
+ )
+
+ with open(model_index, "r") as f:
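+ # each entry in model_index.json is a [library, class_name] pair, so index 1 is the scheduler class name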
+ scheduler_type = json.load(f)["scheduler"][1]
+ return scheduler_type
+
+
+def save_model_card(
+ repo_id: str,
+ use_dora: bool,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ train_text_encoder_ti=False,
+ token_abstraction_dict=None,
+ instance_prompt: str = None,
+ validation_prompt: str = None,
+ repo_folder=None,
+ vae_path=None,
+):
+ img_str = "widget:\n"
+ lora = "lora" if not use_dora else "dora"
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"""
+ - text: '{validation_prompt if validation_prompt else ' ' }'
+ output:
+ url:
+ "image_{i}.png"
+ """
+ if not images:
+ img_str += f"""
+ - text: '{instance_prompt}'
+ """
+ embeddings_filename = f"{repo_folder}_emb"
+ instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
+ ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
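+ # e.g. an instance prompt like "photo of <s0><s1> dog" becomes "photo of {embeddings_filename} dog" for the
+ # webui snippet, while ti_keys collects the inserted token strings themselves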
+ if instance_prompt_webui != embeddings_filename:
+ instance_prompt_sentence = f"For example, `{instance_prompt_webui}`"
+ else:
+ instance_prompt_sentence = ""
+ trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+ diffusers_imports_pivotal = ""
+ diffusers_example_pivotal = ""
+ webui_example_pivotal = ""
+ license = ""
+ if "playground" in base_model:
+ license = """\n
+ ## License
+
+ Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
+ """
+
+ if train_text_encoder_ti:
+ trigger_str = (
+ "To trigger image generation of trained concept(or concepts) replace each concept identifier "
+ "in you prompt with the new inserted tokens:\n"
+ )
+ diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+ """
+ diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
+state_dict = load_file(embedding_path)
+pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+ """
+ webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
+ - Place it in your `embeddings` folder
+ - Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence}
+ (you need both the LoRA and the embeddings as they were trained together for this LoRA)
+ """
+ if token_abstraction_dict:
+ for key, value in token_abstraction_dict.items():
+ tokens = "".join(value)
+ trigger_str += f"""
+to trigger concept `{key}` → use `{tokens}` in your prompt \n
+"""
+
+ yaml = f"""---
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- diffusers-training
+- text-to-image
+- diffusers
+- {lora}
+- template:sd-lora
+{img_str}
+base_model: {base_model}
+instance_prompt: {instance_prompt}
+license: openrail++
+---
+"""
+
+ model_card = f"""
+# SDXL LoRA DreamBooth - {repo_id}
+
+
+
+## Model description
+
+### These are {repo_id} LoRA adaptation weights for {base_model}.
+
+## Download model
+
+### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
+
+- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.
+ - Place it in your `models/Lora` folder.
+ - On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
+{webui_example_pivotal}
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+{diffusers_imports_pivotal}
+pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+{diffusers_example_pivotal}
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## Trigger words
+
+{trigger_str}
+
+## Details
+All [Files & versions](/{repo_id}/tree/main).
+
+The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
+Special VAE used for training: {vae_path}.
+
+{license}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if not args.do_edm_style_training:
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
+ # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand.To load the custom captions, the training set directory needs to follow the structure of a "
+ "datasets ImageFolder, containing both the images and the corresponding caption for each image. see: "
+ "https://huggingface.co/docs/datasets/image_dataset for more information"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset. In some cases, a dataset may have more than one configuration (for example "
+ "if it contains different subsets of data within, and you only wish to load a specific subset - in that case specify the desired configuration using --dataset_config_name. Leave as "
+ "None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help="A path to local folder containing the training data of instance images. Specify this arg instead of "
+ "--dataset_name if you wish to train using a local folder without custom captions. If you wish to train with custom captions please specify "
+ "--dataset_name instead.",
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--token_abstraction",
+ type=str,
+ default="TOK",
+ help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
+ "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. "
+ "'TOK,TOK2,TOK3' etc.",
+ )
+
+ parser.add_argument(
+ "--num_new_tokens_per_abstraction",
+ type=int,
+ default=2,
+ help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
+ "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
+ "tokens - ",
+ )
+
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--do_edm_style_training",
+ default=False,
+ action="store_true",
+ help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--clip_skip",
+ type=int,
+ default=None,
+ help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
+ "the output of the pre-final layer will be used for computing the prompt embeddings.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_ti",
+ action="store_true",
+ help=("Whether to use textual inversion"),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_ti_frac",
+ type=float,
+ default=0.5,
+ help=("The percentage of epochs to perform textual inversion"),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_frac",
+ type=float,
+ default=1.0,
+ help=("The percentage of epochs to perform text encoder tuning"),
+ )
+
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--use_dora",
+ action="store_true",
+ default=False,
+ help=(
+ "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
+ "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
+ ),
+ )
+ parser.add_argument(
+ "--lora_unet_blocks",
+ type=str,
+ default=None,
+ help=(
+ "the U-net blocks to tune during training. please specify them in a comma separated string, e.g. `unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1` etc."
+ "NOTE: By default (if not specified) - regular LoRA training is performed. "
+ "if --use_blora is enabled, this arg will be ignored, since in B-LoRA training, targeted U-net blocks are `unet.up_blocks.0.attentions.0` and `unet.up_blocks.0.attentions.1`"
+ ),
+ )
+ parser.add_argument(
+ "--use_blora",
+ action="store_true",
+ help=(
+ "Whether to train a B-LoRA as proposed in- Implicit Style-Content Separation using B-LoRA https://arxiv.org/abs/2403.14572. "
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ if args.train_text_encoder and args.train_text_encoder_ti:
+ raise ValueError(
+ "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. "
+ "For full LoRA text encoder training check --train_text_encoder, for textual "
+ "inversion training check `--train_text_encoder_ti`"
+ )
+ if args.use_blora and args.lora_unet_blocks:
+ warnings.warn(
+ "You specified both `--use_blora` and `--lora_unet_blocks`, for B-LoRA training, target unet blocks are: `unet.up_blocks.0.attentions.0` and `unet.up_blocks.0.attentions.1`. "
+ "If you wish to target different U-net blocks, don't enable `--use_blora`"
+ )
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+# Taken (and slightly modified) from B-LoRA repo https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py
+def is_belong_to_blocks(key, blocks):
+ try:
+ for g in blocks:
+ if g in key:
+ return True
+ return False
+ except Exception as e:
+ raise type(e)(f"failed to is_belong_to_block, due to: {e}")
+
+
+def get_unet_lora_target_modules(unet, use_blora, target_blocks=None):
+ if use_blora:
+ content_b_lora_blocks = "unet.up_blocks.0.attentions.0"
+ style_b_lora_blocks = "unet.up_blocks.0.attentions.1"
+ target_blocks = [content_b_lora_blocks, style_b_lora_blocks]
+ try:
+ blocks = [(".").join(blk.split(".")[1:]) for blk in target_blocks]
+
+ attns = [
+ attn_processor_name.rsplit(".", 1)[0]
+ for attn_processor_name, _ in unet.attn_processors.items()
+ if is_belong_to_blocks(attn_processor_name, blocks)
+ ]
+
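+ # each entry in attns is an attention-module prefix (the trailing ".processor" was stripped above);
+ # combining it with to_k/to_q/to_v/to_out.0 below yields entries such as
+ # "up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k" (illustrative, the exact names depend on unet.attn_processors)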
+ target_modules = [f"{attn}.{mat}" for mat in ["to_k", "to_q", "to_v", "to_out.0"] for attn in attns]
+ return target_modules
+ except Exception as e:
+ raise type(e)(
+ f"failed to get_target_modules, due to: {e}. "
+ f"Please check the modules specified in --lora_unet_blocks are correct"
+ )
+
+
+# Taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
+class TokenEmbeddingsHandler:
+ def __init__(self, text_encoders, tokenizers):
+ self.text_encoders = text_encoders
+ self.tokenizers = tokenizers
+
+ self.train_ids: Optional[torch.Tensor] = None
+ self.inserting_toks: Optional[List[str]] = None
+ self.embeddings_settings = {}
+
+ def initialize_new_tokens(self, inserting_toks: List[str]):
+ idx = 0
+ for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+ assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
+ assert all(
+ isinstance(tok, str) for tok in inserting_toks
+ ), "All elements in inserting_toks should be strings."
+
+ self.inserting_toks = inserting_toks
+ special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+ tokenizer.add_special_tokens(special_tokens_dict)
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+
+ # random initialization of new tokens
+ std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+
+ print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
+
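+ # the new rows are drawn from a normal distribution scaled by the std of the existing embedding matrix,
+ # so the inserted tokens start roughly in-distribution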
+ text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
+ torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+ .to(device=self.device)
+ .to(dtype=self.dtype)
+ * std_token_embedding
+ )
+ self.embeddings_settings[
+ f"original_embeddings_{idx}"
+ ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
+
+ inu = torch.ones((len(tokenizer),), dtype=torch.bool)
+ inu[self.train_ids] = False
+
+ self.embeddings_settings[f"index_no_updates_{idx}"] = inu
+
+ print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
+
+ idx += 1
+
+ def save_embeddings(self, file_path: str):
+ assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
+ tensors = {}
+ # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
+ idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
+ for idx, text_encoder in enumerate(self.text_encoders):
+ assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
+ self.tokenizers[0]
+ ), "Tokenizers should be the same."
+ new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+
+ # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
+ # text_encoder 1) to keep compatible with the ecosystem.
+ # Note: When loading with diffusers, any name can work - simply specify in inference
+ tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
+ # tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+ save_file(tensors, file_path)
+
+ @property
+ def dtype(self):
+ return self.text_encoders[0].dtype
+
+ @property
+ def device(self):
+ return self.text_encoders[0].device
+
+ @torch.no_grad()
+ def retract_embeddings(self):
+ for idx, text_encoder in enumerate(self.text_encoders):
+ index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
+ text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+ self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
+ .to(device=text_encoder.device)
+ .to(dtype=text_encoder.dtype)
+ )
+
+ # for the parts that were updated, we need to normalize them
+ # to have the same std as before
+ std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
+
+ index_updates = ~index_no_updates
+ new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+ off_ratio = std_token_embedding / new_embeddings.std()
+
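+ # apply only a soft correction (off_ratio ** 0.1) so the updated embeddings are nudged back toward
+ # the original std rather than fully re-normalized at every step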
+ new_embeddings = new_embeddings * (off_ratio**0.1)
+ text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ train_text_encoder_ti,
+ class_data_root=None,
+ class_num=None,
+ token_abstraction_dict=None, # token mapping for textual inversion
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+ self.token_abstraction_dict = token_abstraction_dict
+ self.train_text_encoder_ti = train_text_encoder_ti
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ # image processing to prepare for using SD-XL micro-conditioning
+ self.original_sizes = []
+ self.crop_top_lefts = []
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ # if using B-LoRA with a single image, skip the EXIF transpose and random-flip augmentations
+ single_image = len(self.instance_images) < 2
+ for image in self.instance_images:
+ if not single_image:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ self.original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+
+ if not single_image and args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop or single_image:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ self.crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+
+ self.original_sizes_class_imgs = []
+ self.crop_top_lefts_class_imgs = []
+ self.pixel_values_class_imgs = []
+ self.class_images = [Image.open(path) for path in self.class_images_path]
+ for image in self.class_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ self.original_sizes_class_imgs.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ self.crop_top_lefts_class_imgs.append(crop_top_left)
+ image = train_transforms(image)
+ self.pixel_values_class_imgs.append(image)
+
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ example["instance_images"] = self.pixel_values[index % self.num_instance_images]
+ example["original_size"] = self.original_sizes[index % self.num_instance_images]
+ example["crop_top_left"] = self.crop_top_lefts[index % self.num_instance_images]
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ if self.train_text_encoder_ti:
+ # replace instances of --token_abstraction in the caption with the new placeholder tokens, e.g. "<s0><s1>"
+ for token_abs, token_replacement in self.token_abstraction_dict.items():
+ caption = caption.replace(token_abs, "".join(token_replacement))
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # no custom prompts were provided, fall back to --instance_prompt for all images
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ example["class_prompt"] = self.class_prompt
+ example["class_images"] = self.pixel_values_class_imgs[index % self.num_class_images]
+ example["class_original_size"] = self.original_sizes_class_imgs[index % self.num_class_images]
+ example["class_crop_top_left"] = self.crop_top_lefts_class_imgs[index % self.num_class_images]
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+ original_sizes = [example["original_size"] for example in examples]
+ crop_top_lefts = [example["crop_top_left"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+ original_sizes += [example["class_original_size"] for example in examples]
+ crop_top_lefts += [example["class_crop_top_left"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {
+ "pixel_values": pixel_values,
+ "prompts": prompts,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+ return batch
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ add_special_tokens=add_special_tokens,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
+ )
+
+ # We are always only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds[-1][-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
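+ # concatenate the per-encoder hidden states along the feature dim (for SDXL this is typically
+ # 768 + 1280 = 2048; the exact sizes depend on the loaded text encoders)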
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.do_edm_style_training and args.snr_gamma is not None:
+ raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ model_id = args.hub_model_id or Path(args.output_dir).name
+ repo_id = None
+ if args.push_to_hub:
+ repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ variant=args.variant,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ variant=args.variant,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
+ if "EDM" in scheduler_type:
+ args.do_edm_style_training = True
+ noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ logger.info("Performing EDM-style training!")
+ elif args.do_edm_style_training:
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ logger.info("Performing EDM-style training!")
+ else:
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ latents_mean = latents_std = None
+ if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+ latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ if args.train_text_encoder_ti:
+ # we parse the provided token identifier (or identifiers) into a list, e.g. "TOK" -> ["TOK"],
+ # "TOK, TOK2" -> ["TOK", "TOK2"] etc.
+ token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
+ logger.info(f"list of token identifiers: {token_abstraction_list}")
+
+ token_abstraction_dict = {}
+ token_idx = 0
+ for i, token in enumerate(token_abstraction_list):
+ token_abstraction_dict[token] = [
+ f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction)
+ ]
+ token_idx += args.num_new_tokens_per_abstraction - 1
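+ # e.g. --token_abstraction "TOK" with --num_new_tokens_per_abstraction 2 is expected to map to
+ # {"TOK": ["<s0>", "<s1>"]} (illustrative values)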
+
+ # replace instances of --token_abstraction in --instance_prompt with the new placeholder tokens, e.g. "<s0><s1>"
+ for token_abs, token_replacement in token_abstraction_dict.items():
+ args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
+ if args.with_prior_preservation:
+ args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
+ if args.validation_prompt:
+ args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
+ print("validation prompt:", args.validation_prompt)
+ # initialize the new tokens for textual inversion
+ embedding_handler = TokenEmbeddingsHandler(
+ [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
+ )
+ inserting_toks = []
+ for new_tok in token_abstraction_dict.values():
+ inserting_toks.extend(new_tok)
+ embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks)
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
+ "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+
+ if args.use_blora:
+ # if using B-LoRA, the targeted blocks to train are automatically set
+ target_modules = get_unet_lora_target_modules(unet, use_blora=True)
+ elif args.lora_unet_blocks:
+ # if training specific unet blocks not in the B-LoRA scheme
+ target_blocks_list = "".join(args.lora_unet_blocks.split()).split(",")
+ logger.info(f"list of unet blocks to train: {target_blocks_list}")
+ target_modules = get_unet_lora_target_modules(unet, use_blora=False, target_blocks=target_blocks_list)
+ else:
+ target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+
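+ # rank and alpha are set equal below, so the effective LoRA scaling factor (lora_alpha / r) is 1.0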
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ use_dora=args.use_dora,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ unet.add_adapter(unet_lora_config)
+
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ use_dora=args.use_dora,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+ text_encoder_two.add_adapter(text_lora_config)
+
+ # if we use textual inversion, we freeze all parameters except for the token embeddings
+ # in text encoder
+ elif args.train_text_encoder_ti:
+ text_lora_parameters_one = []
+ for name, param in text_encoder_one.named_parameters():
+ if "token_embedding" in name:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ param.data = param.to(dtype=torch.float32)
+ param.requires_grad = True
+ text_lora_parameters_one.append(param)
+ else:
+ param.requires_grad = False
+ text_lora_parameters_two = []
+ for name, param in text_encoder_two.named_parameters():
+ if "token_embedding" in name:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ param.data = param.to(dtype=torch.float32)
+ param.requires_grad = True
+ text_lora_parameters_two.append(param)
+ else:
+ param.requires_grad = False
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # there are only two options here. Either there are just the unet attn processor layers
+ # or there are the unet and text encoder attn layers
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ if args.train_text_encoder:
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ if args.train_text_encoder:
+ text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+ if args.train_text_encoder_ti:
+ embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ _set_state_dict_into_text_encoder(
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_, text_encoder_two_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one, text_encoder_two])
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+ if args.train_text_encoder:
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
+ # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
+ freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
+
+ # Optimization parameters
+ unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
+ if not freeze_text_encoder:
+ # different learning rate for text encoder and unet
+ text_lora_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder
+ if args.adam_weight_decay_text_encoder
+ else args.adam_weight_decay,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ text_lora_parameters_two_with_lr = {
+ "params": text_lora_parameters_two,
+ "weight_decay": args.adam_weight_decay_text_encoder
+ if args.adam_weight_decay_text_encoder
+ else args.adam_weight_decay,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ unet_lora_parameters_with_lr,
+ text_lora_parameters_one_with_lr,
+ text_lora_parameters_two_with_lr,
+ ]
+ else:
+ params_to_optimize = [unet_lora_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+ # --learning_rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+ params_to_optimize[2]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ train_text_encoder_ti=args.train_text_encoder_ti,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Computes additional embeddings/ids required by the SDXL UNet.
+ # regular text embeddings (when `train_text_encoder` is not True)
+ # pooled text embeddings
+ # time ids
+
+ def compute_time_ids(crops_coords_top_left, original_size=None):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
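+ # the SDXL micro-conditioning vector is (original_h, original_w, crop_top, crop_left, target_h, target_w),
+ # matching how original_size and crop_top_left are stored in the dataset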
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip=clip_skip)
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if freeze_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers, args.clip_skip
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if freeze_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers, args.clip_skip
+ )
+
+ # Clear the memory here
+ if freeze_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
+ add_special_tokens = True if args.train_text_encoder_ti else False
+
+ if not train_dataset.custom_instance_prompts:
+ if freeze_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ unet_add_text_embeds = instance_pooled_prompt_embeds
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
+ # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
+ # batch prompts on all training steps
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, add_special_tokens)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, add_special_tokens)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, add_special_tokens)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ if args.cache_latents:
+ latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=torch.float32
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+ if args.validation_prompt is None:
+ del vae
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if not freeze_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = (
+ "dreambooth-lora-sd-xl"
+ if "playground" not in args.pretrained_model_name_or_path
+ else "dreambooth-lora-playground"
+ )
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
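+ # append trailing singleton dims so sigma broadcasts against the (B, C, H, W) latent tensors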
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ if args.train_text_encoder:
+ num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
+ elif args.train_text_encoder_ti: # args.train_text_encoder_ti
+ num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
+ # flag used for textual inversion
+ pivoted = False
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ # if performing any kind of optimization of text_encoder params
+ if args.train_text_encoder or args.train_text_encoder_ti:
+ if epoch == num_train_epochs_text_encoder:
+ print("PIVOT HALFWAY", epoch)
+ # stopping optimization of text_encoder params
+ # this flag is used to reset the optimizer to optimize only on unet params
+ pivoted = True
+
+ else:
+ # still optimizing the text encoder
+ text_encoder_one.train()
+ text_encoder_two.train()
+ # set requires_grad = True on the top-level embedding parameters so that gradient checkpointing works
+ if args.train_text_encoder:
+ accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+
+ for step, batch in enumerate(train_dataloader):
+ if pivoted:
+ # stopping optimization of text_encoder params
+ # resetting the optimizer to optimize only on unet params
+ optimizer.param_groups[1]["lr"] = 0.0
+ optimizer.param_groups[2]["lr"] = 0.0
+
+ with accelerator.accumulate(unet):
+ prompts = batch["prompts"]
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if freeze_text_encoder:
+ prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
+ prompts, text_encoders, tokenizers, args.clip_skip
+ )
+
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens)
+ tokens_two = tokenize_prompt(tokenizer_two, prompts, add_special_tokens)
+
+ if args.cache_latents:
+ model_input = latents_cache[step].sample()
+ else:
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+
+ if latents_mean is None and latents_std is None:
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+ else:
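+ # some VAEs (such as the playground-v2.5 latent space) expose latents_mean/latents_std in their config;
+ # normalize with them instead of applying only the plain scaling factor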
+ latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
+ latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
+ model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ if not args.do_edm_style_training:
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+ else:
+ # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
+ # instead of discrete timesteps, so here we sample indices to get the noise levels
+ # from `scheduler.timesteps`
+ indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
+ timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+ # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
+ # We then precondition the final model inputs based on these sigmas instead of the timesteps.
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ if args.do_edm_style_training:
+ sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
+ if "EDM" in scheduler_type:
+ inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
+ else:
+ inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
+
+ # time ids
+ add_time_ids = torch.cat(
+ [
+ compute_time_ids(original_size=s, crops_coords_top_left=c)
+ for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+ ]
+ )
+
+ # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
+ if not train_dataset.custom_instance_prompts:
+ elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
+
+ else:
+ elems_to_repeat_text_embeds = 1
+
+ # Predict the noise residual
+ if freeze_text_encoder:
+ unet_added_conditions = {
+ "time_ids": add_time_ids,
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
+ }
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+ else:
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[tokens_one, tokens_two],
+ clip_skip=args.clip_skip,
+ )
+ unet_added_conditions.update(
+ {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
+ )
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ weighting = None
+ if args.do_edm_style_training:
+ # Similar to the input preconditioning, the model predictions are also preconditioned
+ # on noised model inputs (before preconditioning) and the sigmas.
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ if "EDM" in scheduler_type:
+ model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
+ else:
+ if noise_scheduler.config.prediction_type == "epsilon":
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
+ noisy_model_input / (sigmas**2 + 1)
+ )
+                    # We are not doing weighting here because it tends to result in numerical problems.
+ # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+ # There might be other alternatives for weighting as well:
+ # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
+ if "EDM" not in scheduler_type:
+ weighting = (sigmas**-2.0).float()
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = model_input if args.do_edm_style_training else noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = (
+ model_input
+ if args.do_edm_style_training
+ else noise_scheduler.get_velocity(model_input, noise, timesteps)
+ )
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ if weighting is not None:
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+ else:
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ if args.snr_gamma is None:
+ if weighting is not None:
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
+ target.shape[0], -1
+ ),
+ 1,
+ )
+ loss = loss.mean()
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+
+ if args.with_prior_preservation:
+                        # if we're using prior preservation, we compute the SNR for the instance loss only,
+                        # and hence only need the timesteps corresponding to the instance images
+ snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0)
+ else:
+ snr_timesteps = timesteps
+
+ snr = compute_snr(noise_scheduler, snr_timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+ if (args.train_text_encoder or args.train_text_encoder_ti)
+ else unet_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # every step, we reset the embeddings to the original embeddings.
+ if args.train_text_encoder_ti:
+ embedding_handler.retract_embeddings()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ if freeze_text_encoder:
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder_2",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ tokenizer=tokenizer_one,
+ tokenizer_2=tokenizer_two,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ )
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+ )
+ text_encoder_two = unwrap_model(text_encoder_two)
+ text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+ )
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+
+ if args.train_text_encoder_ti:
+ embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
+ embedding_handler.save_embeddings(embeddings_path)
+
+ # Final inference
+ # Load previous pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=True,
+ )
+
+ # Convert to WebUI format
+ lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
+ peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+ save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
+
+ save_model_card(
+ model_id if not args.push_to_hub else repo_id,
+ use_dora=args.use_dora,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ train_text_encoder_ti=args.train_text_encoder_ti,
+ token_abstraction_dict=train_dataset.token_abstraction_dict,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/amused/README.md b/diffusers/examples/amused/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1230bd8667578de596acdca110b534344e477210
--- /dev/null
+++ b/diffusers/examples/amused/README.md
@@ -0,0 +1,326 @@
+## Amused training
+
+Amused can be finetuned on simple datasets relatively cheaply and quickly. Using 8-bit optimizers, LoRA, and gradient accumulation, amused can be finetuned with as little as 5.5 GB of VRAM. Here is a set of examples for finetuning amused on some relatively simple datasets. These training recipes are aggressively oriented towards minimal resources and fast verification -- i.e. the batch sizes are quite low and the learning rates are quite high. For optimal quality, you will probably want to increase the batch sizes and decrease the learning rates.
+
+All training examples use fp16 mixed precision and gradient checkpointing. We don't show 8-bit Adam + LoRA as it's about the same memory use as just using LoRA (bitsandbytes keeps full-precision optimizer states for weights below a minimum size).
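+
+For a quick sanity check of your environment, or to compare against the base model, plain inference goes through the standard `diffusers` pipeline API (the same `AmusedPipeline` the training script instantiates for validation). A minimal sketch; the prompt and file name are just placeholders:
+
+```py
+import torch
+from diffusers import AmusedPipeline
+
+# Base 256x256 checkpoint; the 512 checkpoint works the same way.
+pipe = AmusedPipeline.from_pretrained("amused/amused-256", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+image = pipe("a pixel art character with square red glasses").images[0]
+image.save("sample.png")
+```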
+
+### Finetuning the 256 checkpoint
+
+These examples finetune on this [nouns](https://huggingface.co/datasets/m1guelpf/nouns) dataset.
+
+Example results:
+
+  
+
+
+#### Full finetuning
+
+Batch size: 8, Learning rate: 1e-4, Gives decent results in 750-1000 steps
+
+| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
+|------------|-----------------------------|------------------|-------------|
+| 8 | 1 | 8 | 19.7 GB |
+| 4 | 2 | 8 | 18.3 GB |
+| 1 | 8 | 8 | 17.9 GB |
+
+```sh
+accelerate launch train_amused.py \
+ --output_dir \
+ --train_batch_size \
+ --gradient_accumulation_steps \
+ --learning_rate 1e-4 \
+ --pretrained_model_name_or_path amused/amused-256 \
+ --instance_data_dataset 'm1guelpf/nouns' \
+ --image_key image \
+ --prompt_key text \
+ --resolution 256 \
+ --mixed_precision fp16 \
+ --lr_scheduler constant \
+ --validation_prompts \
+ 'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
+ 'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
+ 'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
+ 'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
+ 'a pixel art character with square red glasses' \
+ 'a pixel art character' \
+ 'square red glasses on a pixel art character' \
+ 'square red glasses on a pixel art character with a baseball-shaped head' \
+ --max_train_steps 10000 \
+ --checkpointing_steps 500 \
+ --validation_steps 250 \
+ --gradient_checkpointing
+```
+
+#### Full finetuning + 8-bit Adam
+
+Note that this training config keeps the batch size low and the learning rate high to get results fast with low resources. However, due to 8-bit Adam, it will diverge eventually. If you want to train for longer, you will have to increase the batch size and lower the learning rate.
+
+Batch size: 16, Learning rate: 2e-5, Gives decent results in ~750 steps
+
+| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
+|------------|-----------------------------|------------------|-------------|
+| 16 | 1 | 16 | 20.1 GB |
+| 8 | 2 | 16 | 15.6 GB |
+| 1 | 16 | 16 | 10.7 GB |
+
+```sh
+accelerate launch train_amused.py \
+ --output_dir \
+ --train_batch_size \
+ --gradient_accumulation_steps \
+ --learning_rate 2e-5 \
+ --use_8bit_adam \
+ --pretrained_model_name_or_path amused/amused-256 \
+ --instance_data_dataset 'm1guelpf/nouns' \
+ --image_key image \
+ --prompt_key text \
+ --resolution 256 \
+ --mixed_precision fp16 \
+ --lr_scheduler constant \
+ --validation_prompts \
+ 'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
+ 'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
+ 'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
+ 'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
+ 'a pixel art character with square red glasses' \
+ 'a pixel art character' \
+ 'square red glasses on a pixel art character' \
+ 'square red glasses on a pixel art character with a baseball-shaped head' \
+ --max_train_steps 10000 \
+ --checkpointing_steps 500 \
+ --validation_steps 250 \
+ --gradient_checkpointing
+```
+
+#### LoRA finetuning
+
+Batch size: 16, Learning rate: 8e-4, Gives decent results in 1000-1250 steps
+
+| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
+|------------|-----------------------------|------------------|-------------|
+| 16 | 1 | 16 | 14.1 GB |
+| 8 | 2 | 16 | 10.1 GB |
+| 1 | 16 | 16 | 6.5 GB |
+
+```sh
+accelerate launch train_amused.py \
+ --output_dir \
+ --train_batch_size \
+ --gradient_accumulation_steps \
+ --learning_rate 8e-4 \
+ --use_lora \
+ --pretrained_model_name_or_path amused/amused-256 \
+ --instance_data_dataset 'm1guelpf/nouns' \
+ --image_key image \
+ --prompt_key text \
+ --resolution 256 \
+ --mixed_precision fp16 \
+ --lr_scheduler constant \
+ --validation_prompts \
+ 'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
+ 'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
+ 'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
+ 'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
+ 'a pixel art character with square red glasses' \
+ 'a pixel art character' \
+ 'square red glasses on a pixel art character' \
+ 'square red glasses on a pixel art character with a baseball-shaped head' \
+ --max_train_steps 10000 \
+ --checkpointing_steps 500 \
+ --validation_steps 250 \
+ --gradient_checkpointing
+```
+
+### Finetuning the 512 checkpoint
+
+These examples finetune on this [minecraft](https://huggingface.co/datasets/monadical-labs/minecraft-preview) dataset.
+
+Example results:
+
+  
+
+#### Full finetuning
+
+Batch size: 8, Learning rate: 8e-5, Gives decent results in 500-1000 steps
+
+| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
+|------------|-----------------------------|------------------|-------------|
+| 8 | 1 | 8 | 24.2 GB |
+| 4 | 2 | 8 | 19.7 GB |
+| 1 | 8 | 8 | 16.99 GB |
+
+```sh
+accelerate launch train_amused.py \
+ --output_dir \
+ --train_batch_size \
+ --gradient_accumulation_steps \
+ --learning_rate 8e-5 \
+ --pretrained_model_name_or_path amused/amused-512 \
+ --instance_data_dataset 'monadical-labs/minecraft-preview' \
+ --prompt_prefix 'minecraft ' \
+ --image_key image \
+ --prompt_key text \
+ --resolution 512 \
+ --mixed_precision fp16 \
+ --lr_scheduler constant \
+ --validation_prompts \
+ 'minecraft Avatar' \
+ 'minecraft character' \
+ 'minecraft' \
+ 'minecraft president' \
+ 'minecraft pig' \
+ --max_train_steps 10000 \
+ --checkpointing_steps 500 \
+ --validation_steps 250 \
+ --gradient_checkpointing
+```
+
+#### Full finetuning + 8-bit Adam
+
+Batch size: 8, Learning rate: 5e-6, Gives decent results in 500-1000 steps
+
+| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
+|------------|-----------------------------|------------------|-------------|
+| 8 | 1 | 8 | 21.2 GB |
+| 4 | 2 | 8 | 13.3 GB |
+| 1 | 8 | 8 | 9.9 GB |
+
+```sh
+accelerate launch train_amused.py \
+ --output_dir \
+ --train_batch_size \
+ --gradient_accumulation_steps \
+ --learning_rate 5e-6 \
+ --pretrained_model_name_or_path amused/amused-512 \
+ --instance_data_dataset 'monadical-labs/minecraft-preview' \
+ --prompt_prefix 'minecraft ' \
+ --image_key image \
+ --prompt_key text \
+ --resolution 512 \
+ --mixed_precision fp16 \
+ --lr_scheduler constant \
+ --validation_prompts \
+ 'minecraft Avatar' \
+ 'minecraft character' \
+ 'minecraft' \
+ 'minecraft president' \
+ 'minecraft pig' \
+ --max_train_steps 10000 \
+ --checkpointing_steps 500 \
+ --validation_steps 250 \
+ --gradient_checkpointing
+```
+
+#### LoRA finetuning
+
+Batch size: 8, Learning rate: 1e-4, Gives decent results in 500-1000 steps
+
+| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
+|------------|-----------------------------|------------------|-------------|
+| 8 | 1 | 8 | 12.7 GB |
+| 4 | 2 | 8 | 9.0 GB |
+| 1 | 8 | 8 | 5.6 GB |
+
+```sh
+accelerate launch train_amused.py \
+ --output_dir \
+ --train_batch_size \
+ --gradient_accumulation_steps \
+ --learning_rate 1e-4 \
+ --use_lora \
+ --pretrained_model_name_or_path amused/amused-512 \
+ --instance_data_dataset 'monadical-labs/minecraft-preview' \
+ --prompt_prefix 'minecraft ' \
+ --image_key image \
+ --prompt_key text \
+ --resolution 512 \
+ --mixed_precision fp16 \
+ --lr_scheduler constant \
+ --validation_prompts \
+ 'minecraft Avatar' \
+ 'minecraft character' \
+ 'minecraft' \
+ 'minecraft president' \
+ 'minecraft pig' \
+ --max_train_steps 10000 \
+ --checkpointing_steps 500 \
+ --validation_steps 250 \
+ --gradient_checkpointing
+```
+
+### Styledrop
+
+[Styledrop](https://arxiv.org/abs/2306.00983) is an efficient finetuning method for learning a new style from just one or very few images. It has an optional first stage that generates additional, human-picked training samples, which can be used to augment the initial images. Our examples skip this optional image-selection stage and instead finetune on a single image.
+
+This is our example style image:
+
+
+Download it to your local directory with
+```sh
+wget https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png
+```
+
+#### 256
+
+Example results:
+
+  
+
+Learning rate: 4e-4, Gives decent results in 1500-2000 steps
+
+Memory used: 6.5 GB
+
+```sh
+accelerate launch train_amused.py \
+ --output_dir \
+ --mixed_precision fp16 \
+ --report_to wandb \
+ --use_lora \
+ --pretrained_model_name_or_path amused/amused-256 \
+ --train_batch_size 1 \
+ --lr_scheduler constant \
+ --learning_rate 4e-4 \
+ --validation_prompts \
+ 'A chihuahua walking on the street in [V] style' \
+ 'A banana on the table in [V] style' \
+ 'A church on the street in [V] style' \
+ 'A tabby cat walking in the forest in [V] style' \
+ --instance_data_image 'A mushroom in [V] style.png' \
+ --max_train_steps 10000 \
+ --checkpointing_steps 500 \
+ --validation_steps 100 \
+ --resolution 256
+```
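+
+Once a run finishes, the learned style LoRA can be loaded back for inference. The snippet below is an illustrative sketch rather than part of the training script: it reuses the same `AmusedLoraLoaderMixin` calls that `train_amused.py` uses when resuming from a checkpoint, and it assumes the LoRA layers were written to a checkpoint directory such as `./styledrop-256/checkpoint-2000` by the script's save hook.
+
+```py
+from diffusers import AmusedPipeline
+from diffusers.loaders import AmusedLoraLoaderMixin
+
+pipe = AmusedPipeline.from_pretrained("amused/amused-256").to("cuda")
+
+# Load the transformer LoRA layers saved by the training script's checkpointing hook.
+lora_state_dict, network_alphas = AmusedLoraLoaderMixin.lora_state_dict("./styledrop-256/checkpoint-2000")
+AmusedLoraLoaderMixin.load_lora_into_transformer(
+    lora_state_dict, network_alphas=network_alphas, transformer=pipe.transformer
+)
+# If --text_encoder_use_lora was also set, load those layers too:
+# AmusedLoraLoaderMixin.load_lora_into_text_encoder(
+#     lora_state_dict, network_alphas=network_alphas, text_encoder=pipe.text_encoder
+# )
+
+image = pipe("A banana on the table in [V] style").images[0]
+image.save("banana_v_style.png")
+```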
+
+#### 512
+
+Example results:
+
+  
+
+Learning rate: 1e-3, Lora alpha 1, Gives decent results in 1500-2000 steps
+
+Memory used: 5.6 GB
+
+```sh
+accelerate launch train_amused.py \
+ --output_dir \
+ --mixed_precision fp16 \
+ --report_to wandb \
+ --use_lora \
+ --pretrained_model_name_or_path amused/amused-512 \
+ --train_batch_size 1 \
+ --lr_scheduler constant \
+ --learning_rate 1e-3 \
+ --validation_prompts \
+ 'A chihuahua walking on the street in [V] style' \
+ 'A banana on the table in [V] style' \
+ 'A church on the street in [V] style' \
+ 'A tabby cat walking in the forest in [V] style' \
+ --instance_data_image 'A mushroom in [V] style.png' \
+ --max_train_steps 100000 \
+ --checkpointing_steps 500 \
+ --validation_steps 100 \
+ --resolution 512 \
+ --lora_alpha 1
+```
\ No newline at end of file
diff --git a/diffusers/examples/amused/train_amused.py b/diffusers/examples/amused/train_amused.py
new file mode 100644
index 0000000000000000000000000000000000000000..ede51775dd8feb14f5670fb9e1e6b7e9cedc2bae
--- /dev/null
+++ b/diffusers/examples/amused/train_amused.py
@@ -0,0 +1,975 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import logging
+import math
+import os
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import DataLoader, Dataset, default_collate
+from torchvision import transforms
+from transformers import (
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+)
+
+import diffusers.optimization
+from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel
+from diffusers.loaders import AmusedLoraLoaderMixin
+from diffusers.utils import is_wandb_available
+
+
+if is_wandb_available():
+ import wandb
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--instance_data_dataset",
+ type=str,
+ default=None,
+ required=False,
+ help="A Hugging Face dataset containing the training images",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--instance_data_image", type=str, default=None, required=False, help="A single training image"
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument("--ema_decay", type=float, default=0.9999)
+ parser.add_argument("--ema_update_after_step", type=int, default=0)
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="muse_training",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
+            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
+            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--logging_steps",
+ type=int,
+ default=50,
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more details"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=0.0003,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="wandb",
+ help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+            ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--validation_prompts", type=str, nargs="*")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument("--split_vae_encode", type=int, required=False, default=None)
+ parser.add_argument("--min_masking_rate", type=float, default=0.0)
+ parser.add_argument("--cond_dropout_prob", type=float, default=0.0)
+ parser.add_argument("--max_grad_norm", default=None, type=float, help="Max gradient norm.", required=False)
+    parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRA")
+    parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the text encoder using LoRA")
+ parser.add_argument("--lora_r", default=16, type=int)
+ parser.add_argument("--lora_alpha", default=32, type=int)
+ parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
+ parser.add_argument("--text_encoder_lora_r", default=16, type=int)
+ parser.add_argument("--text_encoder_lora_alpha", default=32, type=int)
+ parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
+ parser.add_argument("--train_text_encoder", action="store_true")
+ parser.add_argument("--image_key", type=str, required=False)
+ parser.add_argument("--prompt_key", type=str, required=False)
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument("--prompt_prefix", type=str, required=False, default=None)
+
+ args = parser.parse_args()
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ num_datasources = sum(
+ [x is not None for x in [args.instance_data_dir, args.instance_data_image, args.instance_data_dataset]]
+ )
+
+ if num_datasources != 1:
+ raise ValueError(
+ "provide one and only one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`"
+ )
+
+ if args.instance_data_dir is not None:
+ if not os.path.exists(args.instance_data_dir):
+            raise ValueError(f"Does not exist: `--instance_data_dir` {args.instance_data_dir}")
+
+ if args.instance_data_image is not None:
+ if not os.path.exists(args.instance_data_image):
+            raise ValueError(f"Does not exist: `--instance_data_image` {args.instance_data_image}")
+
+ if args.instance_data_dataset is not None and (args.image_key is None or args.prompt_key is None):
+ raise ValueError("`--instance_data_dataset` requires setting `--image_key` and `--prompt_key`")
+
+ return args
+
+
+class InstanceDataRootDataset(Dataset):
+ def __init__(
+ self,
+ instance_data_root,
+ tokenizer,
+ size=512,
+ ):
+ self.size = size
+ self.tokenizer = tokenizer
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+
+ def __len__(self):
+ return len(self.instance_images_path)
+
+ def __getitem__(self, index):
+ image_path = self.instance_images_path[index % len(self.instance_images_path)]
+ instance_image = Image.open(image_path)
+ rv = process_image(instance_image, self.size)
+
+ prompt = os.path.splitext(os.path.basename(image_path))[0]
+ rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0]
+ return rv
+
+
+class InstanceDataImageDataset(Dataset):
+ def __init__(
+ self,
+ instance_data_image,
+ train_batch_size,
+ size=512,
+ ):
+ self.value = process_image(Image.open(instance_data_image), size)
+ self.train_batch_size = train_batch_size
+
+ def __len__(self):
+        # Needed so that a full batch of the data can be returned. Otherwise batches
+        # of size 1 would be returned.
+ return self.train_batch_size
+
+ def __getitem__(self, index):
+ return self.value
+
+
+class HuggingFaceDataset(Dataset):
+ def __init__(
+ self,
+ hf_dataset,
+ tokenizer,
+ image_key,
+ prompt_key,
+ prompt_prefix=None,
+ size=512,
+ ):
+ self.size = size
+ self.image_key = image_key
+ self.prompt_key = prompt_key
+ self.tokenizer = tokenizer
+ self.hf_dataset = hf_dataset
+ self.prompt_prefix = prompt_prefix
+
+ def __len__(self):
+ return len(self.hf_dataset)
+
+ def __getitem__(self, index):
+ item = self.hf_dataset[index]
+
+ rv = process_image(item[self.image_key], self.size)
+
+ prompt = item[self.prompt_key]
+
+ if self.prompt_prefix is not None:
+ prompt = self.prompt_prefix + prompt
+
+ rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0]
+
+ return rv
+
+
+def process_image(image, size):
+ image = exif_transpose(image)
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ orig_height = image.height
+ orig_width = image.width
+
+ image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)
+
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
+ image = transforms.functional.crop(image, c_top, c_left, size, size)
+
+ image = transforms.ToTensor()(image)
+
+ micro_conds = torch.tensor(
+ [orig_width, orig_height, c_top, c_left, 6.0],
+ )
+
+ return {"image": image, "micro_conds": micro_conds}
+
+
+def tokenize_prompt(tokenizer, prompt):
+ return tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=77,
+ return_tensors="pt",
+ ).input_ids
+
+
+def encode_prompt(text_encoder, input_ids):
+ outputs = text_encoder(input_ids, return_dict=True, output_hidden_states=True)
+ encoder_hidden_states = outputs.hidden_states[-2]
+ cond_embeds = outputs[0]
+ return encoder_hidden_states, cond_embeds
+
+
+def main(args):
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if accelerator.is_main_process:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+
+ if accelerator.is_main_process:
+ accelerator.init_trackers("amused", config=vars(copy.deepcopy(args)))
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # TODO - will have to fix loading if training text encoder
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, variant=args.variant
+ )
+ vq_model = VQModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
+ )
+
+ if args.train_text_encoder:
+ if args.text_encoder_use_lora:
+ lora_config = LoraConfig(
+ r=args.text_encoder_lora_r,
+ lora_alpha=args.text_encoder_lora_alpha,
+ target_modules=args.text_encoder_lora_target_modules,
+ )
+ text_encoder.add_adapter(lora_config)
+ text_encoder.train()
+ text_encoder.requires_grad_(True)
+ else:
+ text_encoder.eval()
+ text_encoder.requires_grad_(False)
+
+ vq_model.requires_grad_(False)
+
+ model = UVit2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ if args.use_lora:
+ lora_config = LoraConfig(
+ r=args.lora_r,
+ lora_alpha=args.lora_alpha,
+ target_modules=args.lora_target_modules,
+ )
+ model.add_adapter(lora_config)
+
+ model.train()
+
+ if args.gradient_checkpointing:
+ model.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ if args.use_ema:
+ ema = EMAModel(
+ model.parameters(),
+ decay=args.ema_decay,
+ update_after_step=args.ema_update_after_step,
+ model_cls=UVit2DModel,
+ model_config=model.config,
+ )
+
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ text_encoder_lora_layers_to_save = None
+
+ for model_ in models:
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
+ if args.use_lora:
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model_)
+ else:
+ model_.save_pretrained(os.path.join(output_dir, "transformer"))
+ elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
+ if args.text_encoder_use_lora:
+ text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)
+ else:
+ model_.save_pretrained(os.path.join(output_dir, "text_encoder"))
+ else:
+ raise ValueError(f"unexpected save model: {model_.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
+ AmusedLoraLoaderMixin.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+ )
+
+ if args.use_ema:
+ ema.save_pretrained(os.path.join(output_dir, "ema_model"))
+
+ def load_model_hook(models, input_dir):
+ transformer = None
+ text_encoder_ = None
+
+ while len(models) > 0:
+ model_ = models.pop()
+
+ if isinstance(model_, type(accelerator.unwrap_model(model))):
+ if args.use_lora:
+ transformer = model_
+ else:
+ load_model = UVit2DModel.from_pretrained(os.path.join(input_dir, "transformer"))
+ model_.load_state_dict(load_model.state_dict())
+ del load_model
+            elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
+ if args.text_encoder_use_lora:
+ text_encoder_ = model_
+ else:
+ load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder"))
+ model_.load_state_dict(load_model.state_dict())
+ del load_model
+ else:
+                raise ValueError(f"unexpected save model: {model_.__class__}")
+
+ if transformer is not None or text_encoder_ is not None:
+ lora_state_dict, network_alphas = AmusedLoraLoaderMixin.lora_state_dict(input_dir)
+ AmusedLoraLoaderMixin.load_lora_into_text_encoder(
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
+ )
+ AmusedLoraLoaderMixin.load_lora_into_transformer(
+ lora_state_dict, network_alphas=network_alphas, transformer=transformer
+ )
+
+ if args.use_ema:
+ load_from = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=UVit2DModel)
+ ema.load_state_dict(load_from.state_dict())
+ del load_from
+
+ accelerator.register_load_state_pre_hook(load_model_hook)
+ accelerator.register_save_state_pre_hook(save_model_hook)
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ )
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ # no decay on bias and layernorm and embedding
+ no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": args.adam_weight_decay,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ if args.train_text_encoder:
+ optimizer_grouped_parameters.append(
+ {"params": text_encoder.parameters(), "weight_decay": args.adam_weight_decay}
+ )
+
+ optimizer = optimizer_cls(
+ optimizer_grouped_parameters,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ logger.info("Creating dataloaders and lr_scheduler")
+
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ if args.instance_data_dir is not None:
+ dataset = InstanceDataRootDataset(
+ instance_data_root=args.instance_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ )
+ elif args.instance_data_image is not None:
+ dataset = InstanceDataImageDataset(
+ instance_data_image=args.instance_data_image,
+ train_batch_size=args.train_batch_size,
+ size=args.resolution,
+ )
+ elif args.instance_data_dataset is not None:
+ dataset = HuggingFaceDataset(
+ hf_dataset=load_dataset(args.instance_data_dataset, split="train"),
+ tokenizer=tokenizer,
+ image_key=args.image_key,
+ prompt_key=args.prompt_key,
+ prompt_prefix=args.prompt_prefix,
+ size=args.resolution,
+ )
+ else:
+ assert False
+
+ train_dataloader = DataLoader(
+ dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ num_workers=args.dataloader_num_workers,
+ collate_fn=default_collate,
+ )
+ train_dataloader.num_batches = len(train_dataloader)
+
+ lr_scheduler = diffusers.optimization.get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ )
+
+ logger.info("Preparing model, optimizer and dataloaders")
+
+ if args.train_text_encoder:
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder = accelerator.prepare(
+ model, optimizer, lr_scheduler, train_dataloader, text_encoder
+ )
+ else:
+ model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
+ model, optimizer, lr_scheduler, train_dataloader
+ )
+
+ train_dataloader.num_batches = len(train_dataloader)
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if not args.train_text_encoder:
+ text_encoder.to(device=accelerator.device, dtype=weight_dtype)
+
+ vq_model.to(device=accelerator.device)
+
+ if args.use_ema:
+ ema.to(accelerator.device)
+
+ with nullcontext() if args.train_text_encoder else torch.no_grad():
+ empty_embeds, empty_clip_embeds = encode_prompt(
+ text_encoder, tokenize_prompt(tokenizer, "").to(text_encoder.device, non_blocking=True)
+ )
+
+    # If there is a single instance image, we can pre-encode its prompt once
+ if args.instance_data_image is not None:
+ prompt = os.path.splitext(os.path.basename(args.instance_data_image))[0]
+ encoder_hidden_states, cond_embeds = encode_prompt(
+ text_encoder, tokenize_prompt(tokenizer, prompt).to(text_encoder.device, non_blocking=True)
+ )
+ encoder_hidden_states = encoder_hidden_states.repeat(args.train_batch_size, 1, 1)
+ cond_embeds = cond_embeds.repeat(args.train_batch_size, 1)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ # Afterwards we recalculate our number of training epochs.
+    # Note: We are not doing epoch based training here, but just using this for bookkeeping and being able to
+ # reuse the same training loop with other datasets/loaders.
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(f" Num training steps = {args.max_train_steps}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+
+ resume_from_checkpoint = args.resume_from_checkpoint
+ if resume_from_checkpoint:
+ if resume_from_checkpoint == "latest":
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ if len(dirs) > 0:
+ resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])
+ else:
+ resume_from_checkpoint = None
+
+ if resume_from_checkpoint is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ else:
+ accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}")
+
+ if resume_from_checkpoint is None:
+ global_step = 0
+ first_epoch = 0
+ else:
+ accelerator.load_state(resume_from_checkpoint)
+ global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1])
+ first_epoch = global_step // num_update_steps_per_epoch
+
+    # As stated above, we are not doing epoch based training here, but just using this for bookkeeping and being able to
+ # reuse the same training loop with other datasets/loaders.
+ for epoch in range(first_epoch, num_train_epochs):
+ for batch in train_dataloader:
+ with torch.no_grad():
+ micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True)
+ pixel_values = batch["image"].to(accelerator.device, non_blocking=True)
+
+ batch_size = pixel_values.shape[0]
+
+ split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size
+ num_splits = math.ceil(batch_size / split_batch_size)
+ image_tokens = []
+ for i in range(num_splits):
+ start_idx = i * split_batch_size
+ end_idx = min((i + 1) * split_batch_size, batch_size)
+                    bs = end_idx - start_idx  # number of images in this split (the last split may be smaller)
+ image_tokens.append(
+ vq_model.quantize(vq_model.encode(pixel_values[start_idx:end_idx]).latents)[2][2].reshape(
+ bs, -1
+ )
+ )
+ image_tokens = torch.cat(image_tokens, dim=0)
+
+ batch_size, seq_len = image_tokens.shape
+
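+                # Sample a per-example masking rate from a cosine schedule, then randomly replace that
+                # fraction of image tokens with the mask token id; unmasked positions get label -100 so
+                # they are ignored by the cross-entropy loss.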
+ timesteps = torch.rand(batch_size, device=image_tokens.device)
+ mask_prob = torch.cos(timesteps * math.pi * 0.5)
+ mask_prob = mask_prob.clip(args.min_masking_rate)
+
+ num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
+ batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
+
+ mask_id = accelerator.unwrap_model(model).config.vocab_size - 1
+ input_ids = torch.where(mask, mask_id, image_tokens)
+ labels = torch.where(mask, image_tokens, -100)
+
+ if args.cond_dropout_prob > 0.0:
+ assert encoder_hidden_states is not None
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ mask = (
+ torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1)
+ < args.cond_dropout_prob
+ )
+
+ empty_embeds_ = empty_embeds.expand(batch_size, -1, -1)
+ encoder_hidden_states = torch.where(
+ (encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_
+ )
+
+ empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1)
+ cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_)
+
+ bs = input_ids.shape[0]
+ vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)
+ resolution = args.resolution // vae_scale_factor
+ input_ids = input_ids.reshape(bs, resolution, resolution)
+
+ if "prompt_input_ids" in batch:
+ with nullcontext() if args.train_text_encoder else torch.no_grad():
+ encoder_hidden_states, cond_embeds = encode_prompt(
+ text_encoder, batch["prompt_input_ids"].to(accelerator.device, non_blocking=True)
+ )
+
+ # Train Step
+ with accelerator.accumulate(model):
+ codebook_size = accelerator.unwrap_model(model).config.codebook_size
+
+ logits = (
+ model(
+ input_ids=input_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ micro_conds=micro_conds,
+ pooled_text_emb=cond_embeds,
+ )
+ .reshape(bs, codebook_size, -1)
+ .permute(0, 2, 1)
+ .reshape(-1, codebook_size)
+ )
+
+ loss = F.cross_entropy(
+ logits,
+ labels.view(-1),
+ ignore_index=-100,
+ reduction="mean",
+ )
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ avg_masking_rate = accelerator.gather(mask_prob.repeat(args.train_batch_size)).mean()
+
+ accelerator.backward(loss)
+
+ if args.max_grad_norm is not None and accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema.step(model.parameters())
+
+ if (global_step + 1) % args.logging_steps == 0:
+ logs = {
+ "step_loss": avg_loss.item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ "avg_masking_rate": avg_masking_rate.item(),
+ }
+ accelerator.log(logs, step=global_step + 1)
+
+ logger.info(
+ f"Step: {global_step + 1} "
+ f"Loss: {avg_loss.item():0.4f} "
+ f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
+ )
+
+ if (global_step + 1) % args.checkpointing_steps == 0:
+ save_checkpoint(args, accelerator, global_step + 1)
+
+ if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:
+ if args.use_ema:
+ ema.store(model.parameters())
+ ema.copy_to(model.parameters())
+
+ with torch.no_grad():
+ logger.info("Generating images...")
+
+ model.eval()
+
+ if args.train_text_encoder:
+ text_encoder.eval()
+
+ scheduler = AmusedScheduler.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="scheduler",
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ pipe = AmusedPipeline(
+ transformer=accelerator.unwrap_model(model),
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vqvae=vq_model,
+ scheduler=scheduler,
+ )
+
+ pil_images = pipe(prompt=args.validation_prompts).images
+ wandb_images = [
+ wandb.Image(image, caption=args.validation_prompts[i])
+ for i, image in enumerate(pil_images)
+ ]
+
+ wandb.log({"generated_images": wandb_images}, step=global_step + 1)
+
+ model.train()
+
+ if args.train_text_encoder:
+ text_encoder.train()
+
+ if args.use_ema:
+ ema.restore(model.parameters())
+
+ global_step += 1
+
+ # Stop training if max steps is reached
+ if global_step >= args.max_train_steps:
+ break
+ # End for
+
+ accelerator.wait_for_everyone()
+
+ # Evaluate and save checkpoint at the end of training
+ save_checkpoint(args, accelerator, global_step)
+
+ # Save the final trained checkpoint
+ if accelerator.is_main_process:
+ model = accelerator.unwrap_model(model)
+ if args.use_ema:
+ ema.copy_to(model.parameters())
+ model.save_pretrained(args.output_dir)
+
+ accelerator.end_training()
+
+
+def save_checkpoint(args, accelerator, global_step):
+ output_dir = args.output_dir
+
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if accelerator.is_main_process and args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = Path(output_dir) / f"checkpoint-{global_step}"
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/diffusers/examples/cogvideo/README.md b/diffusers/examples/cogvideo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..398ae9543150493d8fc175d4dafa7eb8ba523cd5
--- /dev/null
+++ b/diffusers/examples/cogvideo/README.md
@@ -0,0 +1,228 @@
+# LoRA finetuning example for CogVideoX
+
+Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow controlling the extent to which the model is adapted toward new training images via a `scale` parameter.
+
+At the moment, LoRA finetuning has only been tested for [CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b).
+
+## Data Preparation
+
+The training script accepts data in two formats.
+
+**First data format**
+
+Two files where one file contains line-separated prompts and another file contains line-separated paths to video data (the path to video files must be relative to the path you pass when specifying `--instance_data_root`). Let's take a look at an example to understand this better!
+
+Assume you've specified `--instance_data_root` as `/dataset`, and that this directory contains the files: `prompts.txt` and `videos.txt`.
+
+The `prompts.txt` file should contain line-separated prompts:
+
+```
+A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.
+A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.
+...
+```
+
+The `videos.txt` file should contain line-separated paths to video files. Note that the path should be _relative_ to the `--instance_data_root` directory.
+
+```
+videos/00000.mp4
+videos/00001.mp4
+...
+```
+
+Overall, this is what your dataset would look like if you ran the `tree` command on the dataset root directory:
+
+```
+/dataset
+├── prompts.txt
+├── videos.txt
+├── videos
+ ├── 00000.mp4
+ ├── 00001.mp4
+ └── ...
+```
+
+When using this format, the `--caption_column` must be `prompts.txt` and `--video_column` must be `videos.txt`.
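+
+Before launching a run, it can help to check that the two files line up, since the training script raises an error if the number of prompts and video paths differ or if any listed path does not resolve to a file. A minimal sketch (using the hypothetical `/dataset` root from above):
+
+```python
+from pathlib import Path
+
+root = Path("/dataset")  # your --instance_data_root
+prompts = [line.strip() for line in (root / "prompts.txt").read_text(encoding="utf-8").splitlines() if line.strip()]
+videos = [root / line.strip() for line in (root / "videos.txt").read_text(encoding="utf-8").splitlines() if line.strip()]
+
+assert len(prompts) == len(videos), "every video needs exactly one caption"
+missing = [path for path in videos if not path.is_file()]
+assert not missing, f"these video paths do not resolve to files: {missing}"
+```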
+
+**Second data format**
+
+You could use a single CSV file. For the sake of this example, assume you have a `metadata.csv` file. The expected format is:
+
+```
+<CAPTION_COLUMN>,<VIDEO_COLUMN>
+"""A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.""","""00000.mp4"""
+"""A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.""","""00001.mp4"""
+...
+```
+
+In this case, the `--instance_data_root` should be the location where the videos are stored, and `--dataset_name` should be either a path to a local folder or a `load_dataset`-compatible hosted HF Dataset repository or URL. Assuming you have your Minecraft gameplay videos at `https://huggingface.co/datasets/my-awesome-username/minecraft-videos`, you would specify `my-awesome-username/minecraft-videos`.
+
+When using this format, `--caption_column` must be `<CAPTION_COLUMN>` and `--video_column` must be `<VIDEO_COLUMN>`.
+
+You are not strictly restricted to the CSV format. As long as the `load_dataset` method supports the file format and can load a caption column and a video-path column, you should be good to go. The reason for going through these dataset-organization gymnastics for loading video data is that we found `load_dataset` from the datasets library does not fully support all kinds of video formats. This will undoubtedly be improved in the future.
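+
+If you want to double-check that your CSV is readable before training, you can load it with `load_dataset` directly. A small sketch, assuming the two columns are named `caption` and `video` (substitute your own column names):
+
+```python
+from datasets import load_dataset
+
+# Load the metadata CSV with 🤗 Datasets, the same library the training script uses.
+dataset = load_dataset("csv", data_files="metadata.csv")["train"]
+
+print(dataset.column_names)        # e.g. ['caption', 'video']
+print(dataset[0]["caption"][:80])  # first caption, truncated for readability
+print(dataset[0]["video"])         # relative path such as 00000.mp4
+```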
+
+> [!NOTE]
+> CogVideoX works best with long and descriptive LLM-augmented prompts for video generation. We recommend pre-processing your videos by first generating a summary using a VLM and then augmenting the prompts with an LLM. To generate the above captions, we use [MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6) and [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct). A very barebones and no-frills example of this is available [here](https://gist.github.com/a-r-r-o-w/4dee20250e82f4e44690a02351324a4a). The official recommendation for augmenting prompts is [ChatGLM](https://huggingface.co/THUDM?search_models=chatglm), and a length of 50-100 words is considered good.
+
+> [!NOTE]
+> It is expected that your dataset is already pre-processed. If not, some basic pre-processing can be done by playing with the following parameters:
+> `--height`, `--width`, `--fps`, `--max_num_frames`, `--skip_frames_start` and `--skip_frames_end`.
+> Presently, all videos in your dataset should contain the same number of video frames when using a training batch size > 1.
+
+
+
+## Training
+
+You need to set up your development environment by installing the necessary requirements. The following packages are required:
+- Torch 2.0 or above based on the training features you are utilizing (might require latest or nightly versions for quantized/deepspeed training)
+- `pip install diffusers transformers accelerate peft huggingface_hub` for all things modeling and training related
+- `pip install datasets decord` for loading video training data
+- `pip install bitsandbytes` for using 8-bit Adam or AdamW optimizers for memory-optimized training
+- `pip install wandb` optionally for monitoring training logs
+- `pip install deepspeed` optionally for [DeepSpeed](https://github.com/microsoft/DeepSpeed) training
+- `pip install prodigyopt` optionally if you would like to use the Prodigy optimizer for training
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or, for a default accelerate configuration without answering questions about your environment:
+
+```bash
+accelerate config default
+```
+
+Or, if your environment doesn't support an interactive shell (e.g., a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups. Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+If you would like to push your model to the HF Hub after training is completed with a neat model card, make sure you're logged in:
+
+```
+huggingface-cli login
+
+# Alternatively, you could upload your model manually using:
+# huggingface-cli upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora
+```
+
+Make sure your data is prepared as described in [Data Preparation](#data-preparation). When ready, you can begin training!
+
+Assuming you are training on 50 videos of a similar concept, we have found 1500-2000 steps to work well. The official recommendation, however, is 100 videos with a total of 4000 steps. Assuming you are training on a single GPU with a `--train_batch_size` of `1`:
+- 1500 steps on 50 videos would correspond to `30` training epochs
+- 4000 steps on 100 videos would correspond to `40` training epochs (see the quick arithmetic check below)
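+
+These numbers follow from `epochs = max_train_steps * train_batch_size * gradient_accumulation_steps / num_videos`. A minimal check, purely for illustration:
+
+```python
+def approx_epochs(steps: int, num_videos: int, batch_size: int = 1, grad_accum: int = 1) -> float:
+    steps_per_epoch = -(-num_videos // (batch_size * grad_accum))  # ceiling division
+    return steps / steps_per_epoch
+
+print(approx_epochs(1500, 50))   # 30.0
+print(approx_epochs(4000, 100))  # 40.0
+```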
+
+```bash
+#!/bin/bash
+
+GPU_IDS="0"
+
+accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \
+ --pretrained_model_name_or_path THUDM/CogVideoX-2b \
+ --cache_dir <CACHE_DIR> \
+ --instance_data_root <PATH_TO_WHERE_VIDEO_FILES_ARE_STORED> \
+ --dataset_name my-awesome-name/my-awesome-dataset \
+ --caption_column <CAPTION_COLUMN> \
+ --video_column <VIDEO_COLUMN> \
+ --id_token <ID_TOKEN> \
+ --validation_prompt "<ID_TOKEN> Spiderman swinging over buildings:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
+ --validation_prompt_separator ::: \
+ --num_validation_videos 1 \
+ --validation_epochs 10 \
+ --seed 42 \
+ --rank 64 \
+ --lora_alpha 64 \
+ --mixed_precision fp16 \
+ --output_dir /raid/aryan/cogvideox-lora \
+ --height 480 --width 720 --fps 8 --max_num_frames 49 --skip_frames_start 0 --skip_frames_end 0 \
+ --train_batch_size 1 \
+ --num_train_epochs 30 \
+ --checkpointing_steps 1000 \
+ --gradient_accumulation_steps 1 \
+ --learning_rate 1e-3 \
+ --lr_scheduler cosine_with_restarts \
+ --lr_warmup_steps 200 \
+ --lr_num_cycles 1 \
+ --enable_slicing \
+ --enable_tiling \
+ --optimizer Adam \
+ --adam_beta1 0.9 \
+ --adam_beta2 0.95 \
+ --max_grad_norm 1.0 \
+ --report_to wandb
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+* `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `--validation_prompt` and `--validation_epochs` to allow the script to do a few validation inference runs. This lets us qualitatively check whether training is progressing as expected.
+
+Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimentation, we found training with it to work better (as it resembles [Dreambooth](https://huggingface.co/docs/diffusers/en/training/dreambooth)-like training) than without. When provided, the `ID_TOKEN` is prepended to the beginning of each prompt. So, if your `ID_TOKEN` was `"DISNEY"` and your prompt was `"Spiderman swinging over buildings"`, the effective prompt used in training would be `"DISNEY Spiderman swinging over buildings"`. When not provided, you would either be training without any such additional token or could augment your dataset to apply the token where you wish before starting the training.
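+
+Concretely, the effective training prompt is just the token string concatenated with each caption; a tiny sketch mirroring the dataset code, with hypothetical values:
+
+```python
+id_token = "DISNEY "  # hypothetical --id_token; a trailing space keeps the caption readable
+caption = "Spiderman swinging over buildings"
+
+# The dataset simply concatenates the token with each caption when --id_token is set.
+effective_prompt = (id_token or "") + caption
+print(effective_prompt)  # DISNEY Spiderman swinging over buildings
+```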
+
+> [!TIP]
+> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
+
+> [!IMPORTANT]
+> The following settings have been tested at the time of adding CogVideoX LoRA training support:
+> - Our testing was primarily done on CogVideoX-2b. We will work on CogVideoX-5b and CogVideoX-5b-I2V soon
+> - One dataset comprising 70 training videos of resolution `200 x 480 x 720` (F x H x W). From this, by using frame skipping in data preprocessing, we created two smaller 49-frame and 16-frame datasets for faster experimentation, and because the maximum limit recommended by the CogVideoX team is 49 frames. Out of the 70 videos, we created three groups of 10, 25 and 50 videos. All videos were similar in nature to the concept being trained.
+> - 25+ videos worked best for training new concepts and styles.
+> - We found that it is better to train with an identifier token that can be specified as `--id_token`. This is similar to Dreambooth-like training but normal finetuning without such a token works too.
+> - The trained concept seemed to work decently well when combined with completely unrelated prompts. We expect even better results if CogVideoX-5B is finetuned.
+> - The original repository uses a `lora_alpha` of `1`. We found this not suitable in many runs, possibly due to differences in modeling backends and training settings. Our recommendation is to set `lora_alpha` to either `rank` or `rank // 2`.
+> - If you're training on data whose captions generate bad results with the original model, a `rank` of 64 and above is good and also the recommendation by the team behind CogVideoX. If the generations are already moderately good on your training captions, a `rank` of 16/32 should work. We found that setting the rank too low, say `4`, is not ideal and doesn't produce promising results.
+> - The authors of CogVideoX recommend 4000 training steps and 100 training videos overall to achieve the best result. While that might yield the best results, we found from our limited experimentation that 2000 steps and 25 videos could also be sufficient.
+> - When using the Prodigy optimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From our very limited testing, we found a learning rate of `0.5` to be suitable in addition to `--prodigy_use_bias_correction`, `--prodigy_safeguard_warmup` and `--prodigy_decouple`.
+> - The recommended learning rate by the CogVideoX authors and from our experimentation with Adam/AdamW is between `1e-3` and `1e-4` for a dataset of 25+ videos.
+>
+> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.
+
+
+
+## Inference
+
+Once you have trained a LoRA model, inference can be done by simply loading the LoRA weights into the `CogVideoXPipeline`.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.utils import export_to_video
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
+# pipe.load_lora_weights("/path/to/lora/weights", adapter_name="cogvideox-lora") # Or,
+pipe.load_lora_weights("my-awesome-hf-username/my-awesome-lora-name", adapter_name="cogvideox-lora") # If loading from the HF Hub
+pipe.to("cuda")
+
+# Assuming lora_alpha=32 and rank=64 for training. If different, set accordingly
+pipe.set_adapters(["cogvideox-lora"], [32 / 64])
+
+prompt = (
+ "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The "
+ "panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ "atmosphere of this unique musical performance"
+)
+frames = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]
+export_to_video(frames, "output.mp4", fps=8)
+```
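+
+If you are short on VRAM at inference time, the usual CogVideoX memory savers can be applied to the LoRA-loaded pipeline as well; a small optional sketch (the speed/memory trade-off depends on your setup):
+
+```python
+# Optional memory optimizations (use instead of, or in addition to, `pipe.to("cuda")`)
+pipe.enable_model_cpu_offload()  # keep submodules on the GPU only while they run
+pipe.vae.enable_slicing()        # decode frames in slices to reduce peak memory
+pipe.vae.enable_tiling()         # decode frames in spatial tiles
+```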
diff --git a/diffusers/examples/cogvideo/requirements.txt b/diffusers/examples/cogvideo/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c2238804be9f4b41300ac2962fd76afca53f1e71
--- /dev/null
+++ b/diffusers/examples/cogvideo/requirements.txt
@@ -0,0 +1,10 @@
+accelerate>=0.31.0
+torchvision
+transformers>=4.41.2
+ftfy
+tensorboard
+Jinja2
+peft>=0.11.1
+sentencepiece
+decord>=0.6.0
+imageio-ffmpeg
\ No newline at end of file
diff --git a/diffusers/examples/cogvideo/train_cogvideox_lora.py b/diffusers/examples/cogvideo/train_cogvideox_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..137f3222f6d901990464388c478a9416d16f587e
--- /dev/null
+++ b/diffusers/examples/cogvideo/train_cogvideox_lora.py
@@ -0,0 +1,1544 @@
+# Copyright 2024 The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
+
+import diffusers
+from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
+from diffusers.training_utils import (
+ cast_training_params,
+ clear_objs_and_retain_memory,
+)
+from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")
+
+ # Model information
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ # Dataset information
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_root",
+ type=str,
+ default=None,
+ help=("A folder containing the training data."),
+ )
+ parser.add_argument(
+ "--video_column",
+ type=str,
+ default="video",
+ help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.",
+ )
+ parser.add_argument(
+ "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ # Validation
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
+ )
+ parser.add_argument(
+ "--validation_prompt_separator",
+ type=str,
+ default=":::",
+ help="String that separates multiple validation prompts",
+ )
+ parser.add_argument(
+ "--num_validation_videos",
+ type=int,
+ default=1,
+ help="Number of videos that should be generated during validation per `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`."
+ ),
+ )
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=6,
+ help="The guidance scale to use while sampling validation videos.",
+ )
+ parser.add_argument(
+ "--use_dynamic_cfg",
+ action="store_true",
+ default=False,
+ help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
+ )
+
+ # Training information
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=128,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=float,
+ default=128,
+ help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="cogvideox-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--height",
+ type=int,
+ default=480,
+ help="All input videos are resized to this height.",
+ )
+ parser.add_argument(
+ "--width",
+ type=int,
+ default=720,
+ help="All input videos are resized to this width.",
+ )
+ parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
+ parser.add_argument(
+ "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
+ )
+ parser.add_argument(
+ "--skip_frames_start",
+ type=int,
+ default=0,
+ help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
+ )
+ parser.add_argument(
+ "--skip_frames_end",
+ type=int,
+ default=0,
+ help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip videos horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--enable_slicing",
+ action="store_true",
+ default=False,
+ help="Whether or not to use VAE slicing for saving memory.",
+ )
+ parser.add_argument(
+ "--enable_tiling",
+ action="store_true",
+ default=False,
+ help="Whether or not to use VAE tiling for saving memory.",
+ )
+
+ # Optimizer
+ parser.add_argument(
+ "--optimizer",
+ type=lambda s: s.lower(),
+ default="adam",
+ choices=["adam", "adamw", "prodigy"],
+ help=("The optimizer type to use."),
+ )
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
+ )
+ parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.")
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ action="store_true",
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
+ )
+
+ # Other information
+ parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help="Directory where logs are stored.",
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default=None,
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+
+ return parser.parse_args()
+
+
+class VideoDataset(Dataset):
+ def __init__(
+ self,
+ instance_data_root: Optional[str] = None,
+ dataset_name: Optional[str] = None,
+ dataset_config_name: Optional[str] = None,
+ caption_column: str = "text",
+ video_column: str = "video",
+ height: int = 480,
+ width: int = 720,
+ fps: int = 8,
+ max_num_frames: int = 49,
+ skip_frames_start: int = 0,
+ skip_frames_end: int = 0,
+ cache_dir: Optional[str] = None,
+ id_token: Optional[str] = None,
+ ) -> None:
+ super().__init__()
+
+ self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
+ self.dataset_name = dataset_name
+ self.dataset_config_name = dataset_config_name
+ self.caption_column = caption_column
+ self.video_column = video_column
+ self.height = height
+ self.width = width
+ self.fps = fps
+ self.max_num_frames = max_num_frames
+ self.skip_frames_start = skip_frames_start
+ self.skip_frames_end = skip_frames_end
+ self.cache_dir = cache_dir
+ self.id_token = id_token or ""
+
+ if dataset_name is not None:
+ self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
+ else:
+ self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
+
+ self.num_instance_videos = len(self.instance_video_paths)
+ if self.num_instance_videos != len(self.instance_prompts):
+ raise ValueError(
+ f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
+ )
+
+ self.instance_videos = self._preprocess_data()
+
+ def __len__(self):
+ return self.num_instance_videos
+
+ def __getitem__(self, index):
+ return {
+ "instance_prompt": self.id_token + self.instance_prompts[index],
+ "instance_video": self.instance_videos[index],
+ }
+
+ def _load_dataset_from_hub(self):
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_root instead."
+ )
+
+ # Downloading and loading a dataset from the hub. See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ self.dataset_name,
+ self.dataset_config_name,
+ cache_dir=self.cache_dir,
+ )
+ column_names = dataset["train"].column_names
+
+ if self.video_column is None:
+ video_column = column_names[0]
+ logger.info(f"`video_column` defaulting to {video_column}")
+ else:
+ video_column = self.video_column
+ if video_column not in column_names:
+ raise ValueError(
+ f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if self.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"`caption_column` defaulting to {caption_column}")
+ else:
+ caption_column = self.caption_column
+ if self.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{self.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ instance_prompts = dataset["train"][caption_column]
+ instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]]
+
+ return instance_prompts, instance_videos
+
+ def _load_dataset_from_local_path(self):
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance videos root folder does not exist")
+
+ prompt_path = self.instance_data_root.joinpath(self.caption_column)
+ video_path = self.instance_data_root.joinpath(self.video_column)
+
+ if not prompt_path.exists() or not prompt_path.is_file():
+ raise ValueError(
+ "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts."
+ )
+ if not video_path.exists() or not video_path.is_file():
+ raise ValueError(
+ "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory."
+ )
+
+ with open(prompt_path, "r", encoding="utf-8") as file:
+ instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
+ with open(video_path, "r", encoding="utf-8") as file:
+ instance_videos = [
+ self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0
+ ]
+
+ if any(not path.is_file() for path in instance_videos):
+ raise ValueError(
+ "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found atleast one path that is not a valid file."
+ )
+
+ return instance_prompts, instance_videos
+
+ def _preprocess_data(self):
+ try:
+ import decord
+ except ImportError:
+ raise ImportError(
+ "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
+ )
+
+ decord.bridge.set_bridge("torch")
+
+ videos = []
+ train_transforms = transforms.Compose(
+ [
+ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
+ ]
+ )
+
+ for filename in self.instance_video_paths:
+ video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+ video_num_frames = len(video_reader)
+
+ start_frame = min(self.skip_frames_start, video_num_frames)
+ end_frame = max(0, video_num_frames - self.skip_frames_end)
+ if end_frame <= start_frame:
+ frames = video_reader.get_batch([start_frame])
+ elif end_frame - start_frame <= self.max_num_frames:
+ frames = video_reader.get_batch(list(range(start_frame, end_frame)))
+ else:
+ indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
+ frames = video_reader.get_batch(indices)
+
+ # Ensure that we don't go over the limit
+ frames = frames[: self.max_num_frames]
+ selected_num_frames = frames.shape[0]
+
+ # Keep only the first (4k + 1) frames, since this is the frame count the VAE requires
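+ # The CogVideoX VAE compresses the temporal dimension by a factor of 4, so the frame count must satisfy F % 4 == 1.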
+ remainder = (3 + (selected_num_frames % 4)) % 4
+ if remainder != 0:
+ frames = frames[:-remainder]
+ selected_num_frames = frames.shape[0]
+
+ assert (selected_num_frames - 1) % 4 == 0
+
+ # Training transforms
+ frames = frames.float()
+ frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
+ videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W]
+
+ return videos
+
+
+def save_model_card(
+ repo_id: str,
+ videos=None,
+ base_model: str = None,
+ validation_prompt=None,
+ repo_folder=None,
+ fps=8,
+):
+ widget_dict = []
+ if videos is not None:
+ for i, video in enumerate(videos):
+ export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}}
+ )
+
+ model_description = f"""
+# CogVideoX LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} LoRA weights for {base_model}.
+
+The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
+
+Was LoRA for the text encoder enabled? No.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import CogVideoXPipeline
+import torch
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
+pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
+
+# The LoRA adapter weights are determined by what was used for training.
+# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
+# It can be made lower or higher than what was used in training to decrease or amplify the effect
+# of the LoRA up to a tolerance, beyond which one might notice no effect at all or overflows.
+pipe.set_adapters(["cogvideox-lora"], [32 / 64])
+
+video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=validation_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-video",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "cogvideox",
+ "cogvideox-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipe,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation: bool = False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
+ )
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipe.scheduler.config:
+ variance_type = pipe.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
+ pipe = pipe.to(accelerator.device)
+ # pipe.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+ videos = []
+ for _ in range(args.num_validation_videos):
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
+ videos.append(video)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "wandb":
+ video_filenames = []
+ for i, video in enumerate(videos):
+ prompt = (
+ pipeline_args["prompt"][:25]
+ .replace(" ", "_")
+ .replace("'", "_")
+ .replace('"', "_")
+ .replace("/", "_")
+ )
+ filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
+ export_to_video(video, filename, fps=8)
+ video_filenames.append(filename)
+
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
+ for i, filename in enumerate(video_filenames)
+ ]
+ }
+ )
+
+ clear_objs_and_retain_memory([pipe])
+
+ return videos
+
+
+def _get_t5_prompt_embeds(
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def encode_prompt(
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt_embeds = _get_t5_prompt_embeds(
+ tokenizer,
+ text_encoder,
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ text_input_ids=text_input_ids,
+ )
+ return prompt_embeds
+
+
+def compute_prompt_embeddings(
+ tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
+):
+ if requires_grad:
+ prompt_embeds = encode_prompt(
+ tokenizer,
+ text_encoder,
+ prompt,
+ num_videos_per_prompt=1,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ with torch.no_grad():
+ prompt_embeds = encode_prompt(
+ tokenizer,
+ text_encoder,
+ prompt,
+ num_videos_per_prompt=1,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ return prompt_embeds
+
+
+def prepare_rotary_positional_embeddings(
+ height: int,
+ width: int,
+ num_frames: int,
+ vae_scale_factor_spatial: int = 8,
+ patch_size: int = 2,
+ attention_head_dim: int = 64,
+ device: Optional[torch.device] = None,
+ base_height: int = 480,
+ base_width: int = 720,
+) -> Tuple[torch.Tensor, torch.Tensor]:
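+ # RoPE is computed over the latent patch grid: spatial dims are divided by the VAE spatial scale factor and the transformer patch size.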
+ grid_height = height // (vae_scale_factor_spatial * patch_size)
+ grid_width = width // (vae_scale_factor_spatial * patch_size)
+ base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
+ base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
+
+ grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
+
+def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
+ # Use the DeepSpeed optimizer
+ if use_deepspeed:
+ from accelerate.utils import DummyOptim
+
+ return DummyOptim(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ # Optimizer creation
+ supported_optimizers = ["adam", "adamw", "prodigy"]
+ if args.optimizer not in supported_optimizers:
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ if args.optimizer.lower() == "adamw":
+ optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+ elif args.optimizer.lower() == "adam":
+ optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+ elif args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ return optimizer
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Prepare models and scheduler
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ text_encoder = T5EncoderModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+
+ # CogVideoX-2b weights are stored in float16
+ # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
+ load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ torch_dtype=load_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ vae = AutoencoderKLCogVideoX.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+
+ scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ if args.enable_slicing:
+ vae.enable_slicing()
+ if args.enable_tiling:
+ vae.enable_tiling()
+
+ # We only train the additional adapter LoRA layers
+ text_encoder.requires_grad_(False)
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.state.deepspeed_plugin:
+ # DeepSpeed is handling precision, use what's in the DeepSpeed config
+ if (
+ "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
+ and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
+ ):
+ weight_dtype = torch.float16
+ if (
+ "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
+ and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
+ ):
+ weight_dtype = torch.bfloat16
+ else:
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ # now we will add new LoRA weights to the attention layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.lora_alpha,
+ init_lora_weights=True,
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ CogVideoXPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"Unexpected save model: {model.__class__}")
+
+ lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params([transformer_])
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params([transformer], dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
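+ # When DeepSpeed defines the optimizer/scheduler in its own config, Accelerate's Dummy wrappers are used as placeholders below.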
+ use_deepspeed_optimizer = (
+ accelerator.state.deepspeed_plugin is not None
+ and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
+ )
+ use_deepspeed_scheduler = (
+ accelerator.state.deepspeed_plugin is not None
+ and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+ )
+
+ optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
+
+ # Dataset and DataLoader
+ train_dataset = VideoDataset(
+ instance_data_root=args.instance_data_root,
+ dataset_name=args.dataset_name,
+ dataset_config_name=args.dataset_config_name,
+ caption_column=args.caption_column,
+ video_column=args.video_column,
+ height=args.height,
+ width=args.width,
+ fps=args.fps,
+ max_num_frames=args.max_num_frames,
+ skip_frames_start=args.skip_frames_start,
+ skip_frames_end=args.skip_frames_end,
+ cache_dir=args.cache_dir,
+ id_token=args.id_token,
+ )
+
+ def encode_video(video):
+ video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
+ video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
+ latent_dist = vae.encode(video).latent_dist
+ return latent_dist
+
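+    # Pre-encode all training videos into VAE latent distributions once, so the VAE does not have to run at every training step.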
+ train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+
+ def collate_fn(examples):
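+        # Draw a latent sample from each cached posterior and apply the VAE scaling factor.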
+ videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ videos = torch.cat(videos)
+ videos = videos.to(memory_format=torch.contiguous_format).float()
+
+ return {
+ "videos": videos,
+ "prompts": prompts,
+ }
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ if use_deepspeed_scheduler:
+ from accelerate.utils import DummyScheduler
+
+ lr_scheduler = DummyScheduler(
+ name=args.lr_scheduler,
+ optimizer=optimizer,
+ total_num_steps=args.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ )
+ else:
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = args.tracker_name or "cogvideox-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num trainable parameters = {num_trainable_parameters}")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if not args.resume_from_checkpoint:
+ initial_global_step = 0
+ else:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
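+    # Each VAE down block halves the spatial resolution, so the overall spatial downscale factor is 2 ** (num_blocks - 1).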
+ vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
+
+ # For DeepSpeed training
+ model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+
+ with accelerator.accumulate(models_to_accumulate):
+ model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
+ prompts = batch["prompts"]
+
+ # encode prompts
+ prompt_embeds = compute_prompt_embeddings(
+ tokenizer,
+ text_encoder,
+ prompts,
+ model_config.max_text_seq_length,
+ accelerator.device,
+ weight_dtype,
+ requires_grad=False,
+ )
+
+ # Sample noise that will be added to the latents
+ noise = torch.randn_like(model_input)
+ batch_size, num_frames, num_channels, height, width = model_input.shape
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Prepare rotary embeds
+ image_rotary_emb = (
+ prepare_rotary_positional_embeddings(
+ height=args.height,
+ width=args.width,
+ num_frames=num_frames,
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
+ patch_size=model_config.patch_size,
+ attention_head_dim=model_config.attention_head_dim,
+ device=accelerator.device,
+ )
+ if model_config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
+
+ # Predict the noise residual
+ model_output = transformer(
+ hidden_states=noisy_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timesteps,
+ image_rotary_emb=image_rotary_emb,
+ return_dict=False,
+ )[0]
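+                # Convert the predicted velocity (v-prediction) back into a prediction of the clean latents,
+                # so the loss below is computed against the original `model_input`.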
+ model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
+
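+                # Weight each sample's loss by 1 / (1 - alpha_bar_t) = SNR(t) + 1, which emphasizes low-noise timesteps.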
+ alphas_cumprod = scheduler.alphas_cumprod[timesteps]
+ weights = 1 / (1 - alphas_cumprod)
+ while len(weights.shape) < len(model_pred.shape):
+ weights = weights.unsqueeze(-1)
+
+ target = model_input
+
+ loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
+ loss = loss.mean()
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ if accelerator.state.deepspeed_plugin is None:
+ optimizer.step()
+ optimizer.zero_grad()
+
+ lr_scheduler.step()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
+ # Create pipeline
+ pipe = CogVideoXPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=unwrap_model(transformer),
+ text_encoder=unwrap_model(text_encoder),
+ vae=unwrap_model(vae),
+ scheduler=scheduler,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
+ for validation_prompt in validation_prompts:
+ pipeline_args = {
+ "prompt": validation_prompt,
+ "guidance_scale": args.guidance_scale,
+ "use_dynamic_cfg": args.use_dynamic_cfg,
+ "height": args.height,
+ "width": args.width,
+ }
+
+ validation_outputs = log_validation(
+ pipe=pipe,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ )
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ dtype = (
+ torch.float16
+ if args.mixed_precision == "fp16"
+ else torch.bfloat16
+ if args.mixed_precision == "bf16"
+ else torch.float32
+ )
+ transformer = transformer.to(dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ CogVideoXPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ )
+
+ # Final test inference
+ pipe = CogVideoXPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
+
+ if args.enable_slicing:
+ pipe.vae.enable_slicing()
+ if args.enable_tiling:
+ pipe.vae.enable_tiling()
+
+ # Load LoRA weights
+ lora_scaling = args.lora_alpha / args.rank
+ pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
+ pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
+
+ # Run inference
+ validation_outputs = []
+ if args.validation_prompt and args.num_validation_videos > 0:
+ validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
+ for validation_prompt in validation_prompts:
+ pipeline_args = {
+ "prompt": validation_prompt,
+ "guidance_scale": args.guidance_scale,
+ "use_dynamic_cfg": args.use_dynamic_cfg,
+ "height": args.height,
+ "width": args.width,
+ }
+
+ video = log_validation(
+ pipe=pipe,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+ validation_outputs.extend(video)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ videos=validation_outputs,
+ base_model=args.pretrained_model_name_or_path,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ fps=args.fps,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/diffusers/examples/community/README.md b/diffusers/examples/community/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e51124e759560abd6c829c90f597e45de31dec70
--- /dev/null
+++ b/diffusers/examples/community/README.md
@@ -0,0 +1,4402 @@
+# Community Pipeline Examples
+
+> **For more information about community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).**
+
+**Community pipeline** examples consist of pipelines that have been added by the community.
+Please have a look at the following tables to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste ready code example that you can try out.
+If a community pipeline doesn't work as expected, please open an issue and ping the author on it.
+
+Please also check out our [Community Scripts](https://github.com/huggingface/diffusers/blob/main/examples/community/README_community_scripts.md) examples for tips and tricks that you can use with diffusers without having to run a community pipeline.
+
+| Example | Description | Code Example | Colab | Author |
+|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
+|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
+|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[](https://huggingface.co/spaces/exx8/differential-diffusion) [](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
+| HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
+| Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [](https://huggingface.co/spaces/toshas/marigold) [](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
+| LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) |
+| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
+| One Step U-Net (Dummy) | Example showcasing how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) |
+| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
+| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech)
+| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) |
+| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "\|" in prompts (as an AND condition) and weights (separated by "\|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
+| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
+| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
+| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
+| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | - | [Phạm Hồng Vinh](https://github.com/rootonchair) |
+| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
+| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
+| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
+| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
+| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
+| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
+| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) |
+| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
+| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
+| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) |
+| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
+| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
+| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | - | [Joqsan Azocar](https://github.com/Joqsan) |
+| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint ) | - | [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
+| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
+| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) |
+| CLIP Guided Images Mixing Stable Diffusion Pipeline | Сombine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) |
+| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
+| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
+| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
+| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
+| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
+| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
+| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
+| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
+| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
+| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
+| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | - | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |
+| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
+| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
+| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
+| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#demofusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
+| Instaflow Pipeline | Implementation of [InstaFlow! One-Step Stable Diffusion with Rectified Flow](https://arxiv.org/abs/2309.06380) | [Instaflow Pipeline](#instaflow-pipeline) | - | [Ayush Mangal](https://github.com/ayushtues) |
+| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
+| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#rerender-a-video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
+| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
+| InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |
+| UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |
+| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
+| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
+| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
+| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
+| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffsuion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
+
+To load a custom pipeline, just pass the `custom_pipeline` argument to `DiffusionPipeline`, set to the filename of one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
+
+```py
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", custom_pipeline="filename_in_the_community_folder")
+```
+
+## Example usages
+
+### Flux with CFG
+
+Learn more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use classifier-free guidance (CFG) by default, this implementation adds it, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
+
+Example usage:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+ custom_pipeline="pipeline_flux_with_cfg"
+)
+pipeline.enable_model_cpu_offload()
+prompt = "a watercolor painting of a unicorn"
+negative_prompt = "pink"
+
+img = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ true_cfg=1.5,
+ guidance_scale=3.5,
+ num_images_per_prompt=1,
+ generator=torch.manual_seed(0)
+).images[0]
+img.save("cfg_flux.png")
+```
+
+### Differential Diffusion
+
+**Eran Levin, Ohad Fried**
+
+**Tel Aviv University, Reichman University**
+
+Diffusion models have revolutionized image generation and editing, producing state-of-the-art results in conditioned and unconditioned image synthesis. While current techniques enable user control over the degree of change in an image edit, the controllability is limited to global changes over an entire edited region. This paper introduces a novel framework that enables customization of the amount of change per pixel or per image region. Our framework can be integrated into any existing diffusion model, enhancing it with this capability. Such granular control on the quantity of change opens up a diverse array of new editing capabilities, such as control of the extent to which individual objects are modified, or the ability to introduce gradual spatial changes. Furthermore, we showcase the framework's effectiveness in soft-inpainting---the completion of portions of an image while subtly adjusting the surrounding areas to ensure seamless integration. Additionally, we introduce a new tool for exploring the effects of different change quantities. Our framework operates solely during inference, requiring no model training or fine-tuning. We demonstrate our method with the current open state-of-the-art models, and validate it via both quantitative and qualitative comparisons, and a user study.
+
+
+
+You can find additional information about Differential Diffusion in the [paper](https://differential-diffusion.github.io/paper.pdf) or in the [project website](https://differential-diffusion.github.io/).
+
+#### Usage example
+
+```python
+import torch
+from torchvision import transforms
+
+from diffusers import DPMSolverMultistepScheduler
+from diffusers.utils import load_image
+from examples.community.pipeline_stable_diffusion_xl_differential_img2img import (
+ StableDiffusionXLDifferentialImg2ImgPipeline,
+)
+
+
+pipeline = StableDiffusionXLDifferentialImg2ImgPipeline.from_pretrained(
+ "SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)
+
+
+def preprocess_image(image):
+ image = image.convert("RGB")
+ image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
+ image = transforms.ToTensor()(image)
+ image = image * 2 - 1
+ image = image.unsqueeze(0).to("cuda")
+ return image
+
+
+def preprocess_map(map):
+ map = map.convert("L")
+ map = transforms.CenterCrop((map.size[1] // 64 * 64, map.size[0] // 64 * 64))(map)
+ map = transforms.ToTensor()(map)
+ map = map.to("cuda")
+ return map
+
+
+image = preprocess_image(
+ load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true"
+ )
+)
+
+mask = preprocess_map(
+ load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true"
+ )
+)
+
+prompt = "a green pear"
+negative_prompt = "blurry"
+
+image = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ guidance_scale=7.5,
+ num_inference_steps=25,
+ original_image=image,
+ image=image,
+ strength=1.0,
+ map=mask,
+).images[0]
+
+image.save("result.png")
+```
+
+### HD-Painter
+
+Implementation of [HD-Painter: High-Resolution and Prompt-Faithful Text-Guided Image Inpainting with Diffusion Models](https://arxiv.org/abs/2312.14091).
+
+
+
+The abstract from the paper is:
+
+Recent progress in text-guided image inpainting, based on the unprecedented success of text-to-image diffusion models, has led to exceptionally realistic and visually plausible results.
+However, there is still significant potential for improvement in current text-to-image inpainting models, particularly in better aligning the inpainted area with user prompts and performing high-resolution inpainting.
+Therefore, in this paper we introduce _HD-Painter_, a completely **training-free** approach that **accurately follows to prompts** and coherently **scales to high-resolution** image inpainting.
+To this end, we design the _Prompt-Aware Introverted Attention (PAIntA)_ layer enhancing self-attention scores by prompt information and resulting in better text alignment generations.
+To further improve the prompt coherence we introduce the _Reweighting Attention Score Guidance (RASG)_ mechanism seamlessly integrating a post-hoc sampling strategy into general form of DDIM to prevent out-of-distribution latent shifts.
+Moreover, HD-Painter allows extension to larger scales by introducing a specialized super-resolution technique customized for inpainting, enabling the completion of missing regions in images of up to 2K resolution.
+Our experiments demonstrate that HD-Painter surpasses existing state-of-the-art approaches qualitatively and quantitatively, achieving an impressive generation accuracy improvement of **61.4** vs **51.9**.
+We will make the codes publicly available.
+
+You can find additional information about HD-Painter in the [paper](https://arxiv.org/abs/2312.14091) or the [original codebase](https://github.com/Picsart-AI-Research/HD-Painter).
+
+#### Usage example
+
+```python
+import torch
+from diffusers import DiffusionPipeline, DDIMScheduler
+from diffusers.utils import load_image, make_image_grid
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-inpainting",
+ custom_pipeline="hd_painter"
+)
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+prompt = "wooden boat"
+init_image = load_image("https://raw.githubusercontent.com/Picsart-AI-Research/HD-Painter/main/__assets__/samples/images/2.jpg")
+mask_image = load_image("https://raw.githubusercontent.com/Picsart-AI-Research/HD-Painter/main/__assets__/samples/masks/2.png")
+
+image = pipe(prompt, init_image, mask_image, use_rasg=True, use_painta=True, generator=torch.manual_seed(12345)).images[0]
+
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+### Marigold Depth Estimation
+
+Marigold is a universal monocular depth estimator that delivers accurate and sharp predictions in the wild. Based on Stable Diffusion, it is trained exclusively with synthetic depth data and excels in zero-shot adaptation to real-world imagery. This pipeline is an official implementation of the inference process. More details can be found on our [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) (also implemented with diffusers).
+
+
+
+This depth estimation pipeline processes a single input image through multiple diffusion denoising stages to estimate depth maps. These maps are subsequently merged to produce the final output. Below is an example code snippet, including optional arguments:
+
+```python
+import numpy as np
+import torch
+from PIL import Image
+from diffusers import DiffusionPipeline
+from diffusers.utils import load_image
+
+# Original DDIM version (higher quality)
+pipe = DiffusionPipeline.from_pretrained(
+ "prs-eth/marigold-v1-0",
+ custom_pipeline="marigold_depth_estimation"
+ # torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
+ # variant="fp16", # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
+)
+
+# (New) LCM version (faster speed)
+pipe = DiffusionPipeline.from_pretrained(
+ "prs-eth/marigold-depth-lcm-v1-0",
+ custom_pipeline="marigold_depth_estimation"
+ # torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
+ # variant="fp16", # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
+)
+
+pipe.to("cuda")
+
+img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_example.jpg"
+image: Image.Image = load_image(img_path_or_url)
+
+pipeline_output = pipe(
+ image, # Input image.
+ # ----- recommended setting for DDIM version -----
+ # denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10.
+ # ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10.
+ # ------------------------------------------------
+
+ # ----- recommended setting for LCM version ------
+ # denoising_steps=4,
+ # ensemble_size=5,
+ # -------------------------------------------------
+
+ # processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
+ # match_input_res=True, # (optional) Resize depth prediction to match input resolution.
+ # batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
+ # seed=2024, # (optional) Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing --batch_size 1 helps to increase reproducibility. To ensure full reproducibility, deterministic mode needs to be used.
+ # color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". Set to `None` to skip colormap generation.
+ # show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.
+)
+
+depth: np.ndarray = pipeline_output.depth_np # Predicted depth map
+depth_colored: Image.Image = pipeline_output.depth_colored # Colorized prediction
+
+# Save as uint16 PNG
+depth_uint16 = (depth * 65535.0).astype(np.uint16)
+Image.fromarray(depth_uint16).save("./depth_map.png", mode="I;16")
+
+# Save colorized depth map
+depth_colored.save("./depth_colored.png")
+```
+
+### LLM-grounded Diffusion
+
+LMD and LMD+ greatly improve the prompt understanding ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. They improve spatial reasoning, the understanding of negation, attribute binding, generative numeracy, etc. in a unified manner without explicitly aiming for each. LMD is completely training-free (i.e., it uses the SD model off-the-shelf). LMD+ takes in additional adapters for better control. This is a reproduction of the LMD+ model used in our work. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion)
+
+
+
+
+This pipeline can be used with an LLM or on its own. We provide a parser that parses LLM outputs to the layouts. You can obtain the prompt to input to the LLM for layout generation [here](https://github.com/TonyLianLong/LLM-groundedDiffusion/blob/main/prompt.py). After feeding the prompt to an LLM (e.g., GPT-4 on ChatGPT website), you can feed the LLM response into our pipeline.
+
+The following code has been tested on 1x RTX 4090, but it should also run on GPUs with less memory.
+
+#### Use this pipeline with an LLM
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "longlian/lmd_plus",
+ custom_pipeline="llm_grounded_diffusion",
+ custom_revision="main",
+ variant="fp16", torch_dtype=torch.float16
+)
+pipe.enable_model_cpu_offload()
+
+# Generate directly from a text prompt and an LLM response
+prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
+phrases, boxes, bg_prompt, neg_prompt = pipe.parse_llm_response("""
+[('a waterfall', [71, 105, 148, 258]), ('a modern high speed train', [255, 223, 181, 149])]
+Background prompt: A beautiful forest with fall foliage
+Negative prompt:
+""")
+
+images = pipe(
+ prompt=prompt,
+ negative_prompt=neg_prompt,
+ phrases=phrases,
+ boxes=boxes,
+ gligen_scheduled_sampling_beta=0.4,
+ output_type="pil",
+ num_inference_steps=50,
+ lmd_guidance_kwargs={}
+).images
+
+images[0].save("./lmd_plus_generation.jpg")
+```
+
+#### Use this pipeline on its own for layout generation
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "longlian/lmd_plus",
+ custom_pipeline="llm_grounded_diffusion",
+ variant="fp16", torch_dtype=torch.float16
+)
+pipe.enable_model_cpu_offload()
+
+# Generate an image described by the prompt and
+# insert objects described by text at the region defined by bounding boxes
+prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
+boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]
+phrases = ["a waterfall", "a modern high speed train"]
+
+images = pipe(
+ prompt=prompt,
+ phrases=phrases,
+ boxes=boxes,
+ gligen_scheduled_sampling_beta=0.4,
+ output_type="pil",
+ num_inference_steps=50,
+ lmd_guidance_kwargs={}
+).images
+
+images[0].save("./lmd_plus_generation.jpg")
+```
+
+### CLIP Guided Stable Diffusion
+
+CLIP guided stable diffusion can help to generate more realistic images
+by guiding stable diffusion at every denoising step with an additional CLIP model.
+
+The following code requires roughly 12GB of GPU RAM.
+
+```python
+from diffusers import DiffusionPipeline
+from transformers import CLIPImageProcessor, CLIPModel
+import torch
+
+
+feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
+clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16)
+
+
+guided_pipeline = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ torch_dtype=torch.float16,
+)
+guided_pipeline.enable_attention_slicing()
+guided_pipeline = guided_pipeline.to("cuda")
+
+prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
+
+generator = torch.Generator(device="cuda").manual_seed(0)
+images = []
+for i in range(4):
+ image = guided_pipeline(
+ prompt,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ clip_guidance_scale=100,
+ num_cutouts=4,
+ use_cutouts=False,
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+# save images locally
+for i, img in enumerate(images):
+ img.save(f"./clip_guided_sd/image_{i}.png")
+```
+
+The `images` list contains PIL images that can be saved locally or displayed directly in a notebook such as Google Colab.
+Generated images tend to be of higher quality than those produced by the base Stable Diffusion pipeline.
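+
+If you are working in a notebook, a quick way to preview the results is the `make_image_grid` helper from `diffusers.utils` (a minimal sketch, reusing the `images` list from the snippet above):
+
+```python
+from diffusers.utils import make_image_grid
+
+# Arrange the four generated images in a 2x2 grid for quick inspection
+grid = make_image_grid(images, rows=2, cols=2)
+grid.save("./clip_guided_sd/grid.png")
+```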
+
+### One Step Unet
+
+The dummy "one-step-unet" can be run as follows:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")
+pipe()
+```
+
+**Note**: This community pipeline is not useful as a feature, but rather just serves as an example of how community pipelines can be added (see https://github.com/huggingface/diffusers/issues/841).
+
+### Stable Diffusion Interpolation
+
+The following code can be run on a GPU with at least 8GB of VRAM and should take approximately 5 minutes.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ variant='fp16',
+ torch_dtype=torch.float16,
+ safety_checker=None, # Very important for videos...lots of false positives while interpolating
+ custom_pipeline="interpolate_stable_diffusion",
+).to('cuda')
+pipe.enable_attention_slicing()
+
+frame_filepaths = pipe.walk(
+ prompts=['a dog', 'a cat', 'a horse'],
+ seeds=[42, 1337, 1234],
+ num_interpolation_steps=16,
+ output_dir='./dreams',
+ batch_size=4,
+ height=512,
+ width=512,
+ guidance_scale=8.5,
+ num_inference_steps=50,
+)
+```
+
+The `walk(...)` function returns a list of the frames it saves under the folder defined by `output_dir`. You can use these images to create videos of your Stable Diffusion interpolations.
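+
+For example, here is a minimal sketch that stitches the saved frames into an `.mp4` with OpenCV (assuming `frame_filepaths` holds the image paths returned by `walk(...)` above; the output filename and frame rate are arbitrary choices):
+
+```python
+import cv2
+
+# Read the interpolation frames back in the order they were generated
+frames = [cv2.imread(path) for path in frame_filepaths]
+height, width, _ = frames[0].shape
+
+writer = cv2.VideoWriter("dreams.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 24, (width, height))
+for frame in frames:
+    writer.write(frame)
+writer.release()
+```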
+
+> **Please have a look at [stable-diffusion-videos](https://github.com/nateraw/stable-diffusion-videos) for more detailed information on how to create videos using stable diffusion as well as more feature-complete functionality.**
+
+### Stable Diffusion Mega
+
+The Stable Diffusion Mega Pipeline lets you use the main use cases of the stable diffusion pipeline in a single class.
+
+```python
+#!/usr/bin/env python3
+from diffusers import DiffusionPipeline
+import PIL
+import requests
+from io import BytesIO
+import torch
+
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_mega", torch_dtype=torch.float16, variant="fp16")
+pipe.to("cuda")
+pipe.enable_attention_slicing()
+
+
+### Text-to-Image
+images = pipe.text2img("An astronaut riding a horse").images
+
+### Image-to-Image
+init_image = download_image("https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg")
+
+prompt = "A fantasy landscape, trending on artstation"
+
+images = pipe.img2img(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+
+### Inpainting
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+
+prompt = "a cat sitting on a bench"
+images = pipe.inpaint(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.75).images
+```
+
+As shown above, this single pipeline can run "text-to-image", "image-to-image", and "inpainting" from one class.
+
+### Long Prompt Weighting Stable Diffusion
+
+Features of this custom pipeline:
+
+- Input a prompt without the 77 token length limit.
+- Includes text2img, img2img, and inpainting pipelines.
+- Emphasize/weigh part of your prompt with parentheses, like so: `a baby deer with (big eyes)`
+- De-emphasize part of your prompt, like so: `a [baby] deer with big eyes`
+- Precisely weigh part of your prompt, like so: `a baby deer with (big eyes:1.3)`
+
+Prompt weighting equivalents (each extra pair of parentheses multiplies the weight by 1.1, and each pair of square brackets divides it by 1.1):
+
+- `a baby deer with` == `(a baby deer with:1.0)`
+- `(big eyes)` == `(big eyes:1.1)`
+- `((big eyes))` == `(big eyes:1.21)`
+- `[big eyes]` == `(big eyes:0.91)`
+
+You can run this custom pipeline as follows:
+
+#### PyTorch
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ 'hakurei/waifu-diffusion',
+ custom_pipeline="lpw_stable_diffusion",
+ torch_dtype=torch.float16
+)
+pipe = pipe.to("cuda")
+
+prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
+neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
+
+pipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]
+```
+
+#### onnxruntime
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ 'CompVis/stable-diffusion-v1-4',
+ custom_pipeline="lpw_stable_diffusion_onnx",
+ revision="onnx",
+ provider="CUDAExecutionProvider"
+)
+
+prompt = "a photo of an astronaut riding a horse on mars, best quality"
+neg_prompt = "lowres, bad anatomy, error body, error hair, error arm, error hands, bad hands, error fingers, bad fingers, missing fingers, error legs, bad legs, multiple legs, missing legs, error lighting, error shadow, error reflection, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
+
+pipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]
+```
+
+If you see `Token indices sequence length is longer than the specified maximum sequence length for this model ( *** > 77 ) . Running this sequence through the model will result in indexing errors`, do not worry; this warning is expected and harmless for this pipeline.
+
+### Speech to Image
+
+The following code can generate an image from an audio sample using pre-trained OpenAI whisper-small and Stable Diffusion.
+
+```Python
+import torch
+
+import matplotlib.pyplot as plt
+from datasets import load_dataset
+from diffusers import DiffusionPipeline
+from transformers import (
+ WhisperForConditionalGeneration,
+ WhisperProcessor,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+audio_sample = ds[3]
+
+text = audio_sample["text"].lower()
+speech_data = audio_sample["audio"]["array"]
+
+model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
+processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+
+diffuser_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="speech_to_image_diffusion",
+ speech_model=model,
+ speech_processor=processor,
+ torch_dtype=torch.float16,
+)
+
+diffuser_pipeline.enable_attention_slicing()
+diffuser_pipeline = diffuser_pipeline.to(device)
+
+output = diffuser_pipeline(speech_data)
+plt.imshow(output.images[0])
+```
+
+This example produces the following image:
+
+
+
+### Wildcard Stable Diffusion
+
+Here's a minimal implementation, following great examples from the community, that allows users to add "wildcards", denoted by `__wildcard__`, to prompts; these serve as placeholders for values sampled randomly from either a dictionary or a `.txt` file. For example:
+
+Say we have a prompt:
+
+```
+prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
+```
+
+We can then define possible values to be sampled for `animal`, `object`, and `clothing`. These can come from a `.txt` file with the same name as the category.
+
+The possible values can also be defined or combined by using a dictionary like: `{"animal": ["dog", "cat", "mouse"]}`.
+
+The actual pipeline works just like `StableDiffusionPipeline`, except the `__call__` method takes in:
+
+- `wildcard_files`: a list of file paths for wildcard replacement
+- `wildcard_option_dict`: a dict mapping a wildcard name to a list of possible replacements
+- `num_prompt_samples`: the number of prompts to sample, uniformly sampling wildcards
+
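+The substitution itself amounts to replacing each `__category__` token with a random choice for that category. A minimal sketch of the idea, using a hypothetical helper rather than the pipeline's internal code:
+
+```python
+import random
+import re
+
+def replace_wildcards(prompt: str, options: dict) -> str:
+    # Replace each __category__ token with a random choice from options[category]
+    return re.sub(r"__(\w+)__", lambda m: random.choice(options[m.group(1)]), prompt)
+
+print(replace_wildcards(
+    "__animal__ sitting on a __object__",
+    {"animal": ["dog", "cat", "mouse"], "object": ["chair", "sofa", "bench"]},
+))
+```
+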
+A full example:
+
+create `animal.txt`, with contents like:
+
+```
+dog
+cat
+mouse
+```
+
+create `object.txt`, with contents like:
+
+```
+chair
+sofa
+bench
+```
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="wildcard_stable_diffusion",
+ torch_dtype=torch.float16,
+)
+prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
+out = pipe(
+ prompt,
+ wildcard_option_dict={
+ "clothing":["hat", "shirt", "scarf", "beret"]
+ },
+ wildcard_files=["object.txt", "animal.txt"],
+ num_prompt_samples=1
+)
+```
+
+### Composable Stable Diffusion
+
+[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models.
+
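+Under the hood, composition combines the per-prompt noise predictions with the given weights relative to the unconditional prediction. The sketch below illustrates that weighted combination on dummy tensors (illustrative only; the community pipeline's internals may differ in detail):
+
+```python
+import torch
+
+# Dummy noise predictions standing in for the UNet outputs at one denoising step
+noise_uncond = torch.randn(1, 4, 64, 64)                     # unconditional prediction
+noise_cond = [torch.randn(1, 4, 64, 64) for _ in range(3)]   # one prediction per sub-prompt
+weights = [7.5, 7.5, -7.5]                                   # positive = conjunction, negative = negation
+
+composed = noise_uncond + sum(
+    w * (pred - noise_uncond) for w, pred in zip(weights, noise_cond)
+)
+print(composed.shape)  # torch.Size([1, 4, 64, 64])
+```
+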
+```python
+import torch as th
+import numpy as np
+import torchvision.utils as tvu
+
+from diffusers import DiffusionPipeline
+
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--prompt", type=str, default="mystical trees | A magical pond | dark",
+ help="use '|' as the delimiter to compose separate sentences.")
+parser.add_argument("--steps", type=int, default=50)
+parser.add_argument("--scale", type=float, default=7.5)
+parser.add_argument("--weights", type=str, default="7.5 | 7.5 | -7.5")
+parser.add_argument("--seed", type=int, default=2)
+parser.add_argument("--model_path", type=str, default="CompVis/stable-diffusion-v1-4")
+parser.add_argument("--num_images", type=int, default=1)
+args = parser.parse_args()
+
+has_cuda = th.cuda.is_available()
+device = th.device('cpu' if not has_cuda else 'cuda')
+
+prompt = args.prompt
+scale = args.scale
+steps = args.steps
+
+pipe = DiffusionPipeline.from_pretrained(
+ args.model_path,
+ custom_pipeline="composable_stable_diffusion",
+).to(device)
+
+pipe.safety_checker = None
+
+images = []
+generator = th.Generator(device).manual_seed(args.seed)
+for i in range(args.num_images):
+ image = pipe(prompt, guidance_scale=scale, num_inference_steps=steps,
+ weights=args.weights, generator=generator).images[0]
+ images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)
+grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)
+tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')
+```
+
+### Imagic Stable Diffusion
+
+Imagic allows you to edit an image using Stable Diffusion.
+
+```python
+import requests
+from PIL import Image
+from io import BytesIO
+import torch
+import os
+from diffusers import DiffusionPipeline, DDIMScheduler
+
+has_cuda = torch.cuda.is_available()
+device = torch.device('cpu' if not has_cuda else 'cuda')
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ safety_checker=None,
+ custom_pipeline="imagic_stable_diffusion",
+ scheduler=DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+).to(device)
+generator = torch.Generator(device).manual_seed(0)
+seed = 0
+prompt = "A photo of Barack Obama smiling with a big grin"
+url = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1'
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((512, 512))
+res = pipe.train(
+ prompt,
+ image=init_image,
+ generator=generator)
+res = pipe(alpha=1, guidance_scale=7.5, num_inference_steps=50)
+os.makedirs("imagic", exist_ok=True)
+image = res.images[0]
+image.save('./imagic/imagic_image_alpha_1.png')
+res = pipe(alpha=1.5, guidance_scale=7.5, num_inference_steps=50)
+image = res.images[0]
+image.save('./imagic/imagic_image_alpha_1_5.png')
+res = pipe(alpha=2, guidance_scale=7.5, num_inference_steps=50)
+image = res.images[0]
+image.save('./imagic/imagic_image_alpha_2.png')
+```
+
+### Seed Resizing
+
+This tests seed resizing: first generate an image at 512x512, then generate an image with the same seed at 512x592 using seed resizing, and finally generate a 512x592 image with the original Stable Diffusion pipeline for comparison.
+
+```python
+import os
+import torch as th
+import numpy as np
+from diffusers import DiffusionPipeline
+
+has_cuda = th.cuda.is_available()
+device = th.device('cpu' if not has_cuda else 'cuda')
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="seed_resize_stable_diffusion"
+).to(device)
+
+def dummy(images, **kwargs):
+ return images, False
+
+pipe.safety_checker = dummy
+
+
+images = []
+th.manual_seed(0)
+generator = th.Generator(device).manual_seed(0)
+
+seed = 0
+prompt = "A painting of a futuristic cop"
+
+width = 512
+height = 512
+
+res = pipe(
+ prompt,
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ height=height,
+ width=width,
+ generator=generator)
+image = res.images[0]
+os.makedirs('./seed_resize', exist_ok=True)
+image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
+
+
+th.manual_seed(0)
+generator = th.Generator(device).manual_seed(0)
+
+pipe = DiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4",
+    custom_pipeline="seed_resize_stable_diffusion"
+).to(device)
+
+width = 512
+height = 592
+
+res = pipe(
+ prompt,
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ height=height,
+ width=width,
+ generator=generator)
+image = res.images[0]
+image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
+
+# For comparison, generate the same 512x592 image with the original Stable Diffusion pipeline
+pipe_compare = DiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4"
+).to(device)
+
+res = pipe_compare(
+ prompt,
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ height=height,
+ width=width,
+ generator=generator
+)
+
+image = res.images[0]
+image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height))
+```
+
+### Multilingual Stable Diffusion Pipeline
+
+The following code can generate images from texts in different languages using the pre-trained [mBART-50 many-to-one multilingual machine translation model](https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt) and Stable Diffusion.
+
+```python
+from PIL import Image
+
+import torch
+
+from diffusers import DiffusionPipeline
+from transformers import (
+ pipeline,
+ MBart50TokenizerFast,
+ MBartForConditionalGeneration,
+)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+device_dict = {"cuda": 0, "cpu": -1}
+
+# helper function taken from: https://huggingface.co/blog/stable_diffusion
+def image_grid(imgs, rows, cols):
+ assert len(imgs) == rows*cols
+
+ w, h = imgs[0].size
+ grid = Image.new('RGB', size=(cols*w, rows*h))
+ grid_w, grid_h = grid.size
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i%cols*w, i//cols*h))
+ return grid
+
+# Add language detection pipeline
+language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
+language_detection_pipeline = pipeline("text-classification",
+ model=language_detection_model_ckpt,
+ device=device_dict[device])
+
+# Add model for language translation
+trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
+trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
+
+diffuser_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="multilingual_stable_diffusion",
+ detection_pipeline=language_detection_pipeline,
+ translation_model=trans_model,
+ translation_tokenizer=trans_tokenizer,
+ torch_dtype=torch.float16,
+)
+
+diffuser_pipeline.enable_attention_slicing()
+diffuser_pipeline = diffuser_pipeline.to(device)
+
+prompt = ["a photograph of an astronaut riding a horse",
+ "Una casa en la playa",
+ "Ein Hund, der Orange isst",
+ "Un restaurant parisien"]
+
+output = diffuser_pipeline(prompt)
+
+images = output.images
+
+grid = image_grid(images, rows=2, cols=2)
+```
+
+This example produces the following images:
+
+
+### GlueGen Stable Diffusion Pipeline
+
+GlueGen is a minimal adapter that aligns any encoder (a text encoder for a different language, multilingual RoBERTa, AudioCLIP) with the CLIP text encoder used in the standard Stable Diffusion model. This enables language adaptation of existing English Stable Diffusion checkpoints without the need for an image captioning dataset or long training hours.
+
+Make sure you have downloaded `gluenet_French_clip_overnorm_over3_noln.ckpt` for French from [GlueGen's official repo](https://github.com/salesforce/GlueGen/tree/main) (pre-trained weights for Chinese, Italian, Japanese, and Spanish are also available, or you can train your own).
+
+```python
+from PIL import Image
+
+import torch
+
+from transformers import AutoModel, AutoTokenizer
+
+from diffusers import DiffusionPipeline
+
+if __name__ == "__main__":
+ device = "cuda"
+
+ lm_model_id = "xlm-roberta-large"
+ token_max_length = 77
+
+ text_encoder = AutoModel.from_pretrained(lm_model_id)
+ tokenizer = AutoTokenizer.from_pretrained(lm_model_id, model_max_length=token_max_length, use_fast=False)
+
+ tensor_norm = torch.Tensor([[43.8203],[28.3668],[27.9345],[28.0084],[28.2958],[28.2576],[28.3373],[28.2695],[28.4097],[28.2790],[28.2825],[28.2807],[28.2775],[28.2708],[28.2682],[28.2624],[28.2589],[28.2611],[28.2616],[28.2639],[28.2613],[28.2566],[28.2615],[28.2665],[28.2799],[28.2885],[28.2852],[28.2863],[28.2780],[28.2818],[28.2764],[28.2532],[28.2412],[28.2336],[28.2514],[28.2734],[28.2763],[28.2977],[28.2971],[28.2948],[28.2818],[28.2676],[28.2831],[28.2890],[28.2979],[28.2999],[28.3117],[28.3363],[28.3554],[28.3626],[28.3589],[28.3597],[28.3543],[28.3660],[28.3731],[28.3717],[28.3812],[28.3753],[28.3810],[28.3777],[28.3693],[28.3713],[28.3670],[28.3691],[28.3679],[28.3624],[28.3703],[28.3703],[28.3720],[28.3594],[28.3576],[28.3562],[28.3438],[28.3376],[28.3389],[28.3433],[28.3191]])
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ custom_pipeline="gluegen"
+ ).to(device)
+ pipeline.load_language_adapter("gluenet_French_clip_overnorm_over3_noln.ckpt", num_token=token_max_length, dim=1024, dim_out=768, tensor_norm=tensor_norm)
+
+ prompt = "une voiture sur la plage"
+
+ generator = torch.Generator(device=device).manual_seed(42)
+ image = pipeline(prompt, generator=generator).images[0]
+ image.save("gluegen_output_fr.png")
+```
+
+Which will produce:
+
+
+
+### Image to Image Inpainting Stable Diffusion
+
+Similar to the standard stable diffusion inpainting example, except with the addition of an `inner_image` argument.
+
+`image`, `inner_image`, and `mask` should have the same dimensions. `inner_image` should have an alpha (transparency) channel.
+
+The aim is to overlay two images, then mask out the boundary between `image` and `inner_image` to allow stable diffusion to make the connection more seamless.
+For example, this could be used to place a logo on a shirt and make it blend seamlessly.
+
+```python
+import PIL
+import torch
+
+from diffusers import DiffusionPipeline
+
+image_path = "./path-to-image.png"
+inner_image_path = "./path-to-inner-image.png"
+mask_path = "./path-to-mask.png"
+
+init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
+inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
+mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ custom_pipeline="img2img_inpainting",
+ torch_dtype=torch.float16
+)
+pipe = pipe.to("cuda")
+
+prompt = "Your prompt here!"
+image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
+```
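+
+To preview what the composite of the two inputs looks like before running the pipeline, you can overlay them yourself. A minimal sketch with hypothetical file names (`shirt.png` as the base image and `logo.png` as the transparent inner image):
+
+```python
+from PIL import Image
+
+base = Image.open("./shirt.png").convert("RGBA").resize((512, 512))
+inner = Image.open("./logo.png").convert("RGBA").resize((512, 512))  # transparent background
+
+# Alpha-composite the inner image onto the base image to inspect the overlay
+preview = Image.alpha_composite(base, inner)
+preview.convert("RGB").save("./composite_preview.png")
+```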
+
+
+
+### Text Based Inpainting Stable Diffusion
+
+Use a text prompt to generate the mask for the area to be inpainted.
+It currently uses the CLIPSeg model for mask generation and then calls the standard Stable Diffusion inpainting pipeline to perform the inpainting.
+
+```python
+from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+from diffusers import DiffusionPipeline
+
+from PIL import Image
+import requests
+
+processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ custom_pipeline="text_inpainting",
+ segmentation_model=model,
+ segmentation_processor=processor
+)
+pipe = pipe.to("cuda")
+
+
+url = "https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true"
+image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
+text = "a glass" # will mask out this text
+prompt = "a cup" # the masked out region will be replaced with this
+
+image = pipe(image=image, text=text, prompt=prompt).images[0]
+```
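+
+For reference, the mask-generation step can also be reproduced standalone with CLIPSeg. The sketch below turns the raw CLIPSeg heatmap into a grayscale mask image; it is illustrative and not necessarily the pipeline's exact internal code:
+
+```python
+import numpy as np
+import requests
+import torch
+from PIL import Image
+from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+
+processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+def text_to_mask(image, text):
+    # CLIPSeg returns a low-resolution heatmap; sigmoid + resize turns it into a soft mask
+    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
+    with torch.no_grad():
+        logits = model(**inputs).logits
+    heatmap = torch.sigmoid(logits).squeeze().cpu().numpy()
+    return Image.fromarray((heatmap * 255).astype(np.uint8)).resize(image.size)
+
+url = "https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true"
+image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
+text_to_mask(image, "a glass").save("clipseg_mask.png")
+```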
+
+### Bit Diffusion
+
+Based on the Bit Diffusion approach, this pipeline is used for diffusion on discrete data, e.g. discrete image data or DNA sequence data. An unconditional discrete image can be generated like this:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="bit_diffusion")
+image = pipe().images[0]
+```
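+
+The core idea behind Bit Diffusion is to represent discrete values as "analog bits": each integer is encoded as its binary digits, mapped to {-1, 1}, and then treated as continuous data by an ordinary diffusion model. A minimal sketch of that encoding (illustrative only, not the pipeline's internal code):
+
+```python
+import torch
+
+def int_to_analog_bits(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
+    # Extract each bit of the integers in x, then map {0, 1} -> {-1.0, 1.0}
+    bit_positions = torch.arange(num_bits, device=x.device)
+    bits = (x.unsqueeze(-1) >> bit_positions) & 1
+    return bits.float() * 2.0 - 1.0
+
+pixels = torch.randint(0, 256, (4,))   # e.g. discrete 8-bit pixel values
+print(int_to_analog_bits(pixels))      # shape (4, 8), values in {-1.0, 1.0}
+```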
+
+### Stable Diffusion with K Diffusion
+
+Make sure you have @crowsonkb's `k-diffusion` installed:
+
+```sh
+pip install k-diffusion
+```
+
+You can use the community pipeline as follows:
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
+pipe = pipe.to("cuda")
+
+prompt = "an astronaut riding a horse on mars"
+pipe.set_scheduler("sample_heun")
+seed = 33
+generator = torch.Generator(device="cuda").manual_seed(seed)
+image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
+
+image.save("./astronaut_heun_k_diffusion.png")
+```
+
+To make sure that K Diffusion and `diffusers` yield the same results:
+
+**Diffusers**:
+
+```python
+import torch
+from diffusers import DiffusionPipeline, EulerDiscreteScheduler
+
+seed = 33
+prompt = "an astronaut riding a horse on mars"
+
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+generator = torch.Generator(device="cuda").manual_seed(seed)
+image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
+```
+
+
+
+**K Diffusion**:
+
+```python
+import torch
+from diffusers import DiffusionPipeline, EulerDiscreteScheduler
+
+seed = 33
+prompt = "an astronaut riding a horse on mars"
+
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+pipe.set_scheduler("sample_euler")
+generator = torch.Generator(device="cuda").manual_seed(seed)
+image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
+```
+
+
+
+### Checkpoint Merger Pipeline
+
+Based on the checkpoint merging feature of the AUTOMATIC1111/webui. This is a custom pipeline that merges up to 3 pretrained model checkpoints as long as they are in the Hugging Face `model_index.json` format.
+
+Checkpoint merging is currently memory intensive, as it modifies the weights of a DiffusionPipeline object in place. Expect at least 13GB of RAM usage on Kaggle GPU kernels;
+on Colab you might run out of the 12GB of memory even while merging two checkpoints.
+
+Usage:
+
+```python
+from diffusers import DiffusionPipeline
+
+# Return a CheckpointMergerPipeline class that allows you to merge checkpoints.
+# The checkpoint passed here is ignored. But still pass one of the checkpoints you plan to
+# merge for convenience
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger")
+
+# There are multiple possible scenarios:
+# The pipeline with the merged checkpoints is returned in all the scenarios
+
+# Compatible checkpoints, i.e. matching model_index.json files. Meta attributes in model_index.json (attrs prefixed with _) are ignored during comparison.
+merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-2"], interp="sigmoid", alpha=0.4)
+
+# Incompatible checkpoints in model_index.json but merge might be possible. Use force=True to ignore model_index.json compatibility
+merged_pipe_1 = pipe.merge(["CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion"], force=True, interp="sigmoid", alpha=0.4)
+
+# Three checkpoint merging. Only "add_difference" method actually works on all three checkpoints. Using any other options will ignore the 3rd checkpoint.
+merged_pipe_2 = pipe.merge(["CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion", "prompthero/openjourney"], force=True, interp="add_difference", alpha=0.4)
+
+prompt = "An astronaut riding a horse on Mars"
+
+image = merged_pipe(prompt).images[0]
+```
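+
+Conceptually, every interpolation mode merges the checkpoints tensor by tensor; the modes only differ in the blending formula. A minimal sketch of the simplest case, a plain weighted sum over two hypothetical state dicts `sd_a` and `sd_b` (illustrative only, not the pipeline's internal code):
+
+```python
+import torch
+
+def weighted_sum_merge(sd_a: dict, sd_b: dict, alpha: float) -> dict:
+    # Blend two state dicts key by key: theta = (1 - alpha) * theta_a + alpha * theta_b
+    return {key: (1 - alpha) * sd_a[key] + alpha * sd_b[key] for key in sd_a}
+
+sd_a = {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}
+sd_b = {"weight": torch.zeros(2, 2), "bias": torch.ones(2)}
+print(weighted_sum_merge(sd_a, sd_b, alpha=0.4))
+```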
+
+Some examples along with the merge details:
+
+1. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" ; Sigmoid interpolation; alpha = 0.8
+
+
+
+2. "hakurei/waifu-diffusion" + "prompthero/openjourney" ; Inverse Sigmoid interpolation; alpha = 0.8
+
+
+
+3. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" + "prompthero/openjourney"; Add Difference interpolation; alpha = 0.5
+
+
+
+### Stable Diffusion Comparisons
+
+This Community Pipeline enables the comparison between the 4 checkpoints that exist for Stable Diffusion. They can be found through the following links:
+
+1. [Stable Diffusion v1.1](https://huggingface.co/CompVis/stable-diffusion-v1-1)
+2. [Stable Diffusion v1.2](https://huggingface.co/CompVis/stable-diffusion-v1-2)
+3. [Stable Diffusion v1.3](https://huggingface.co/CompVis/stable-diffusion-v1-3)
+4. [Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4)
+
+```python
+from diffusers import DiffusionPipeline
+import matplotlib.pyplot as plt
+
+pipe = DiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', custom_pipeline='suvadityamuk/StableDiffusionComparison')
+pipe.enable_attention_slicing()
+pipe = pipe.to('cuda')
+prompt = "an astronaut riding a horse on mars"
+output = pipe(prompt)
+
+plt.subplot(2, 2, 1)
+plt.imshow(output.images[0])
+plt.title('Stable Diffusion v1.1')
+plt.axis('off')
+plt.subplot(2, 2, 2)
+plt.imshow(output.images[1])
+plt.title('Stable Diffusion v1.2')
+plt.axis('off')
+plt.subplot(2, 2, 3)
+plt.imshow(output.images[2])
+plt.title('Stable Diffusion v1.3')
+plt.axis('off')
+plt.subplot(2, 2, 4)
+plt.imshow(output.images[3])
+plt.title('Stable Diffusion v1.4')
+plt.axis('off')
+
+plt.show()
+```
+
+As a result, you can look at a grid of all 4 generated images shown together, which captures the differences in training progression across the 4 checkpoints.
+
+### Magic Mix
+
+Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process.
+
+There are 3 parameters for the method:
+
+- `mix_factor`: the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the layout generation process.
+- `kmax` and `kmin`: these determine the range for the layout and content generation processes. A higher value of `kmax` results in more information being lost about the layout of the original image, and a higher value of `kmin` results in more steps for the content generation process.
+
+Here is an example usage:
+
+```python
+from diffusers import DiffusionPipeline, DDIMScheduler
+from PIL import Image
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="magic_mix",
+ scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
+).to('cuda')
+
+img = Image.open('phone.jpg')
+mix_img = pipe(
+ img,
+ prompt='bed',
+ kmin=0.3,
+ kmax=0.5,
+ mix_factor=0.5,
+ )
+mix_img.save('phone_bed_mix.jpg')
+```
+
+`mix_img` is a PIL image that can be saved locally or displayed directly in a Google Colab notebook. The generated image is a mix of the layout semantics of the given image and the content semantics of the prompt.
+
+For example, the above script generates the following image:
+
+`phone.jpg`
+
+
+
+`phone_bed_mix.jpg`
+
+
+
+For more example generations check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb).
+
+### Stable UnCLIP
+
+UnCLIPPipeline("kakaobrain/karlo-v1-alpha") provides a prior model that can generate a CLIP image embedding from text.
+StableDiffusionImageVariationPipeline("lambdalabs/sd-image-variations-diffusers") provides a decoder model that can generate images from a CLIP image embedding.
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "kakaobrain/karlo-v1-alpha",
+ torch_dtype=torch.float16,
+ custom_pipeline="stable_unclip",
+ decoder_pipe_kwargs=dict(
+ image_encoder=None,
+ ),
+)
+pipeline.to(device)
+
+prompt = "a shiba inu wearing a beret and black turtleneck"
+random_generator = torch.Generator(device=device).manual_seed(1000)
+output = pipeline(
+ prompt=prompt,
+ width=512,
+ height=512,
+ generator=random_generator,
+ prior_guidance_scale=4,
+ prior_num_inference_steps=25,
+ decoder_guidance_scale=8,
+ decoder_num_inference_steps=50,
+)
+
+image = output.images[0]
+image.save("./shiba-inu.jpg")
+
+# debug
+
+# `pipeline.decoder_pipe` is a regular StableDiffusionImageVariationPipeline instance.
+# It is used to convert the CLIP image embedding to latents, which are then fed into the VAE decoder.
+print(pipeline.decoder_pipe.__class__)
+#
+
+# This pipeline only uses the prior module from "kakaobrain/karlo-v1-alpha".
+# It is used to convert the CLIP text embedding to a CLIP image embedding.
+print(pipeline)
+# StableUnCLIPPipeline {
+# "_class_name": "StableUnCLIPPipeline",
+# "_diffusers_version": "0.12.0.dev0",
+# "prior": [
+# "diffusers",
+# "PriorTransformer"
+# ],
+# "prior_scheduler": [
+# "diffusers",
+# "UnCLIPScheduler"
+# ],
+# "text_encoder": [
+# "transformers",
+# "CLIPTextModelWithProjection"
+# ],
+# "tokenizer": [
+# "transformers",
+# "CLIPTokenizer"
+# ]
+# }
+
+# pipeline.prior_scheduler is the scheduler used for prior in UnCLIP.
+print(pipeline.prior_scheduler)
+# UnCLIPScheduler {
+# "_class_name": "UnCLIPScheduler",
+# "_diffusers_version": "0.12.0.dev0",
+# "clip_sample": true,
+# "clip_sample_range": 5.0,
+# "num_train_timesteps": 1000,
+# "prediction_type": "sample",
+# "variance_type": "fixed_small_log"
+# }
+```
+
+`shiba-inu.jpg`
+
+
+
+### UnCLIP Text Interpolation Pipeline
+
+This Diffusion Pipeline takes two prompts and interpolates between them using spherical interpolation (slerp). The input prompts are converted to text embeddings by the pipeline's text_encoder, and the interpolation is done on the resulting text embeddings over the number of steps specified (default: 5 steps).
+
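+The interpolation itself follows the standard slerp formula. A minimal sketch over two hypothetical embedding tensors `emb_a` and `emb_b` (illustrative only, not the pipeline's internal code):
+
+```python
+import torch
+
+def slerp(emb_a: torch.Tensor, emb_b: torch.Tensor, t: float, eps: float = 1e-7) -> torch.Tensor:
+    # Spherical interpolation: rotate from emb_a towards emb_b by a fraction t of the angle between them
+    a = emb_a / emb_a.norm()
+    b = emb_b / emb_b.norm()
+    omega = torch.acos((a * b).sum().clamp(-1 + eps, 1 - eps))
+    so = torch.sin(omega)
+    return (torch.sin((1.0 - t) * omega) / so) * emb_a + (torch.sin(t * omega) / so) * emb_b
+
+emb_a, emb_b = torch.randn(768), torch.randn(768)
+midpoint = slerp(emb_a, emb_b, t=0.5)
+```
+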
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+
+pipe = DiffusionPipeline.from_pretrained(
+ "kakaobrain/karlo-v1-alpha",
+ torch_dtype=torch.float16,
+ custom_pipeline="unclip_text_interpolation"
+)
+pipe.to(device)
+
+start_prompt = "A photograph of an adult lion"
+end_prompt = "A photograph of a lion cub"
+# For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths.
+generator = torch.Generator(device=device).manual_seed(42)
+
+output = pipe(start_prompt, end_prompt, steps=6, generator=generator, enable_sequential_cpu_offload=False)
+
+for i, image in enumerate(output.images):
+    image.save('result%s.jpg' % i)
+```
+
+The resulting images, in order:
+
+
+
+
+
+
+
+
+### UnCLIP Image Interpolation Pipeline
+
+This Diffusion Pipeline takes two images or an image_embeddings tensor of size 2 and interpolates between their embeddings using spherical interpolation (slerp). The input images/image_embeddings are converted to image embeddings by the pipeline's image_encoder, and the interpolation is done on the resulting image embeddings over the number of steps specified (default: 5 steps).
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+from PIL import Image
+
+device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
+
+pipe = DiffusionPipeline.from_pretrained(
+ "kakaobrain/karlo-v1-alpha-image-variations",
+ torch_dtype=dtype,
+ custom_pipeline="unclip_image_interpolation"
+)
+pipe.to(device)
+
+images = [Image.open('./starry_night.jpg'), Image.open('./flowers.jpg')]
+generator = torch.Generator(device=device).manual_seed(42)
+
+output = pipe(image=images, steps=6, generator=generator)
+
+for i,image in enumerate(output.images):
+ image.save('starry_to_flowers_%s.jpg' % i)
+```
+
+The original images:
+
+
+
+
+The resulting images, in order:
+
+
+
+
+
+
+
+
+### DDIM Noise Comparative Analysis Pipeline
+
+#### **Research question: What visual concepts do the diffusion models learn from each noise level during training?**
+
+The [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227) paper proposed an approach to answer the above question, which is their second contribution.
+The approach consists of the following steps:
+
+1. The input is an image x0.
+2. Perturb it to xt using a diffusion process q(xt|x0).
+   - `strength` is a value between 0.0 and 1.0 that controls the amount of noise added to the input image. Values approaching 1.0 allow for lots of variation but will also produce images that are not semantically consistent with the input.
+3. Reconstruct the image with the learned denoising process pθ(x̂0|xt).
+4. Compare x0 and x̂0 among various t to show how each step contributes to the sample.
+
+The authors used the [openai/guided-diffusion](https://github.com/openai/guided-diffusion) model to denoise images from the FFHQ dataset. This pipeline extends their second contribution by investigating DDIM on any input image. A minimal sketch of the perturbation step (step 2) follows, before the full pipeline example.
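+
+The snippet below adds schedule noise to a random stand-in tensor `x0` using a diffusers scheduler; the mapping from `strength` to a timestep is a simplification of what img2img-style pipelines do, not this pipeline's exact internal code:
+
+```python
+import torch
+from diffusers import DDPMScheduler
+
+# Sketch of step 2, x_t ~ q(x_t | x_0).
+# `x0` is a hypothetical stand-in for a preprocessed input image normalized to [-1, 1].
+scheduler = DDPMScheduler(num_train_timesteps=1000)
+
+x0 = torch.randn(1, 3, 256, 256)
+noise = torch.randn_like(x0)
+
+strength = 0.5  # fraction of the noise schedule to apply
+t = torch.tensor([int(strength * (scheduler.config.num_train_timesteps - 1))])
+xt = scheduler.add_noise(x0, noise, t)
+```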
+
+```python
+import torch
+from PIL import Image
+import numpy as np
+from diffusers import DiffusionPipeline
+
+image_path = "path/to/your/image" # images from CelebA-HQ might be better
+image_pil = Image.open(image_path)
+image_name = image_path.split("/")[-1].split(".")[0]
+
+device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+pipe = DiffusionPipeline.from_pretrained(
+ "google/ddpm-ema-celebahq-256",
+ custom_pipeline="ddim_noise_comparative_analysis",
+)
+pipe = pipe.to(device)
+
+for strength in np.linspace(0.1, 1, 25):
+ denoised_image, latent_timestep = pipe(
+ image_pil, strength=strength, return_dict=False
+ )
+ denoised_image = denoised_image[0]
+ denoised_image.save(
+ f"noise_comparative_analysis_{image_name}_{latent_timestep}.png"
+ )
+```
+
+Here is the result of this pipeline (which is DDIM) on CelebA-HQ dataset.
+
+
+
+### CLIP Guided Img2Img Stable Diffusion
+
+CLIP guided Img2Img Stable Diffusion can help generate more realistic images from an initial image
+by guiding Stable Diffusion at every denoising step with an additional CLIP model.
+
+The following code requires roughly 12GB of GPU RAM.
+
+```python
+from io import BytesIO
+import requests
+import torch
+from diffusers import DiffusionPipeline
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPModel
+from IPython.display import display  # assuming this runs in a notebook
+
+feature_extractor = CLIPImageProcessor.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+)
+clip_model = CLIPModel.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
+)
+guided_pipeline = DiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4",
+    custom_pipeline="clip_guided_stable_diffusion",
+    clip_model=clip_model,
+    feature_extractor=feature_extractor,
+    torch_dtype=torch.float16,
+)
+guided_pipeline.enable_attention_slicing()
+guided_pipeline = guided_pipeline.to("cuda")
+prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+image = guided_pipeline(
+ prompt=prompt,
+ num_inference_steps=30,
+ image=init_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ clip_guidance_scale=100,
+ num_cutouts=4,
+ use_cutouts=False,
+).images[0]
+display(image)
+```
+
+Init Image
+
+
+
+Output Image
+
+
+
+### TensorRT Text2Image Stable Diffusion Pipeline
+
+The TensorRT Pipeline can be used to accelerate the Text2Image Stable Diffusion Inference run.
+
+NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
+
+```python
+import torch
+from diffusers import DDIMScheduler
+from diffusers.pipelines import DiffusionPipeline
+
+# Use the DDIMScheduler scheduler here instead
+scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
+
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
+ custom_pipeline="stable_diffusion_tensorrt_txt2img",
+ variant='fp16',
+ torch_dtype=torch.float16,
+ scheduler=scheduler,)
+
+# re-use cached folder to save ONNX models and TensorRT Engines
+pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", variant='fp16',)
+
+pipe = pipe.to("cuda")
+
+prompt = "a beautiful photograph of Mt. Fuji during cherry blossom"
+image = pipe(prompt).images[0]
+image.save('tensorrt_mt_fuji.png')
+```
+
+### EDICT Image Editing Pipeline
+
+This pipeline implements the text-guided image editing approach from the paper [EDICT: Exact Diffusion Inversion via Coupled Transformations](https://arxiv.org/abs/2211.12446). You have to pass:
+
+- (`PIL`) `image` you want to edit.
+- `base_prompt`: the text prompt describing the current image (before editing).
+- `target_prompt`: the text prompt describing the desired edits (the image after editing).
+
+```python
+from diffusers import DiffusionPipeline, DDIMScheduler
+from transformers import CLIPTextModel
+import torch, PIL, requests
+from io import BytesIO
+from IPython.display import display
+
+def center_crop_and_resize(im):
+
+ width, height = im.size
+ d = min(width, height)
+ left = (width - d) / 2
+ upper = (height - d) / 2
+ right = (width + d) / 2
+ lower = (height + d) / 2
+
+ return im.crop((left, upper, right, lower)).resize((512, 512))
+
+torch_dtype = torch.float16
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# scheduler and text_encoder param values as in the paper
+scheduler = DDIMScheduler(
+ num_train_timesteps=1000,
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ set_alpha_to_one=False,
+ clip_sample=False,
+)
+
+text_encoder = CLIPTextModel.from_pretrained(
+ pretrained_model_name_or_path="openai/clip-vit-large-patch14",
+ torch_dtype=torch_dtype,
+)
+
+# initialize pipeline
+pipeline = DiffusionPipeline.from_pretrained(
+ pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
+ custom_pipeline="edict_pipeline",
+ variant="fp16",
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ leapfrog_steps=True,
+ torch_dtype=torch_dtype,
+).to(device)
+
+# download image
+image_url = "https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg"
+response = requests.get(image_url)
+image = PIL.Image.open(BytesIO(response.content))
+
+# preprocess it
+cropped_image = center_crop_and_resize(image)
+
+# define the prompts
+base_prompt = "A dog"
+target_prompt = "A golden retriever"
+
+# run the pipeline
+result_image = pipeline(
+ base_prompt=base_prompt,
+ target_prompt=target_prompt,
+ image=cropped_image,
+)
+
+display(result_image)
+```
+
+Init Image
+
+
+
+Output Image
+
+
+
+### Stable Diffusion RePaint
+
+This pipeline uses the [RePaint](https://arxiv.org/abs/2201.09865) logic on the latent space of stable diffusion. It can
+be used similarly to other image inpainting pipelines but does not rely on a specific inpainting model. This means you can use
+models that are not specifically created for inpainting.
+
+Make sure to use the `RePaintScheduler` as shown in the example below.
+
+Disclaimer: The mask is transferred into latent space, which may lead to unexpected changes near the edges of the masked region.
+Inference is also considerably slower.
+
+```py
+import PIL.Image
+import PIL.ImageOps
+import requests
+import torch
+from io import BytesIO
+from diffusers import StableDiffusionPipeline, RePaintScheduler
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+mask_image = PIL.ImageOps.invert(mask_image)
+pipe = StableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, custom_pipeline="stable_diffusion_repaint",
+)
+pipe.scheduler = RePaintScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+```
+
+### TensorRT Image2Image Stable Diffusion Pipeline
+
+The TensorRT Pipeline can be used to accelerate the Image2Image Stable Diffusion Inference run.
+
+NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
+
+```python
+import requests
+from io import BytesIO
+from PIL import Image
+import torch
+from diffusers import DDIMScheduler
+from diffusers import DiffusionPipeline
+
+# Use the DDIMScheduler scheduler here instead
+scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
+ subfolder="scheduler")
+
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
+ custom_pipeline="stable_diffusion_tensorrt_img2img",
+ variant='fp16',
+ torch_dtype=torch.float16,
+ scheduler=scheduler,)
+
+# re-use cached folder to save ONNX models and TensorRT Engines
+pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", variant='fp16',)
+
+pipe = pipe.to("cuda")
+
+url = "https://pajoca.com/wp-content/uploads/2022/09/tekito-yamakawa-1.png"
+response = requests.get(url)
+input_image = Image.open(BytesIO(response.content)).convert("RGB")
+prompt = "photorealistic new zealand hills"
+image = pipe(prompt, image=input_image, strength=0.75,).images[0]
+image.save('tensorrt_img2img_new_zealand_hills.png')
+```
+
+### Stable Diffusion BoxDiff
+
+BoxDiff is a training-free method for controlled generation with bounding box coordinates. It should work with any Stable Diffusion model. Below is an example with `stable-diffusion-2-1-base`.
+
+```py
+import torch
+from PIL import Image, ImageDraw
+from copy import deepcopy
+
+from examples.community.pipeline_stable_diffusion_boxdiff import StableDiffusionBoxDiffPipeline
+
+def draw_box_with_text(img, boxes, names):
+ colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
+ img_new = deepcopy(img)
+ draw = ImageDraw.Draw(img_new)
+
+ W, H = img.size
+ for bid, box in enumerate(boxes):
+ draw.rectangle([box[0] * W, box[1] * H, box[2] * W, box[3] * H], outline=colors[bid % len(colors)], width=4)
+ draw.text((box[0] * W, box[1] * H), names[bid], fill=colors[bid % len(colors)])
+ return img_new
+
+pipe = StableDiffusionBoxDiffPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1-base",
+ torch_dtype=torch.float16,
+)
+pipe.to("cuda")
+
+# example 1
+prompt = "as the aurora lights up the sky, a herd of reindeer leisurely wanders on the grassy meadow, admiring the breathtaking view, a serene lake quietly reflects the magnificent display, and in the distance, a snow-capped mountain stands majestically, fantasy, 8k, highly detailed"
+phrases = [
+ "aurora",
+ "reindeer",
+ "meadow",
+ "lake",
+ "mountain"
+]
+boxes = [[1,3,512,202], [75,344,421,495], [1,327,508,507], [2,217,507,341], [1,135,509,242]]
+
+# example 2
+# prompt = "A rabbit wearing sunglasses looks very proud"
+# phrases = ["rabbit", "sunglasses"]
+# boxes = [[67,87,366,512], [66,130,364,262]]
+
+boxes = [[x / 512 for x in box] for box in boxes]
+
+images = pipe(
+ prompt,
+ boxdiff_phrases=phrases,
+ boxdiff_boxes=boxes,
+ boxdiff_kwargs={
+ "attention_res": 16,
+ "normalize_eot": True
+ },
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ generator=torch.manual_seed(42),
+ safety_checker=None
+).images
+
+draw_box_with_text(images[0], boxes, phrases).save("output.png")
+```
+
+
+### Stable Diffusion Reference
+
+This pipeline uses the Reference Control. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236) and the [sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280).
+
+Based on [this issue](https://github.com/huggingface/diffusers/issues/3566):
+
+- `EulerAncestralDiscreteScheduler` gave poor results.
+
+```py
+import torch
+from diffusers import UniPCMultistepScheduler
+from diffusers.utils import load_image
+# StableDiffusionReferencePipeline is provided by the community file stable_diffusion_reference.py
+from examples.community.stable_diffusion_reference import StableDiffusionReferencePipeline
+
+input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+pipe = StableDiffusionReferencePipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ safety_checker=None,
+ torch_dtype=torch.float16
+ ).to('cuda:0')
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+```
+
+Reference Image
+
+
+
+Output Image of `reference_attn=True` and `reference_adain=False`
+
+
+
+Output Image of `reference_attn=False` and `reference_adain=True`
+
+
+
+Output Image of `reference_attn=True` and `reference_adain=True`
+
+
+
+### Stable Diffusion ControlNet Reference
+
+This pipeline uses the Reference Control with ControlNet. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236) and the [sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280).
+
+Based on [this issue](https://github.com/huggingface/diffusers/issues/3566):
+
+- `EulerAncestralDiscreteScheduler` gave poor results.
+- `guess_mode=True` works well for ControlNet v1.1.
+
+```py
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+from diffusers import ControlNetModel, UniPCMultistepScheduler
+from diffusers.utils import load_image
+# StableDiffusionControlNetReferencePipeline is provided by the community file stable_diffusion_controlnet_reference.py
+from examples.community.stable_diffusion_controlnet_reference import StableDiffusionControlNetReferencePipeline
+
+input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+# get canny image
+image = cv2.Canny(np.array(input_image), 100, 200)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ controlnet=controlnet,
+ safety_checker=None,
+ torch_dtype=torch.float16
+ ).to('cuda:0')
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ image=canny_image,
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+```
+
+Reference Image
+
+
+
+Output Image
+
+
+
+### Stable Diffusion on IPEX
+
+This diffusion pipeline aims to accelerate the inference of Stable-Diffusion on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).
+
+To use this pipeline, you need to:
+
+1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)
+
+**Note:** For each PyTorch release, there is a corresponding release of IPEX. Here is the mapping relationship. It is recommended to install PyTorch/IPEX 2.0 to get the best performance.
+
+|PyTorch Version|IPEX Version|
+|--|--|
+|[v2.0.\*](https://github.com/pytorch/pytorch/tree/v2.0.1 "v2.0.1")|[v2.0.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|
+|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|
+
+You can simply use pip to install IPEX with the latest version.
+
+```sh
+python -m pip install intel_extension_for_pytorch
+```
+
+**Note:** To install a specific version, run with the following command:
+
+```sh
+python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
+```
+
+2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX acceleration. Supported inference datatypes are Float32 and BFloat16.
+
+**Note:** The generated image height/width passed to `prepare_for_ipex()` should be the same as the values used at pipeline inference time.
+
+```python
+pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
+# For Float32
+pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) # value of image height/width should be consistent with the pipeline inference
+# For BFloat16
+pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) # value of image height/width should be consistent with the pipeline inference
+```
+
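+Conceptually, `prepare_for_ipex()` applies IPEX optimizations to the pipeline's submodules, roughly along the lines of the standalone `ipex.optimize()` call shown below (an illustrative sketch on a toy module, not the pipeline's exact internals):
+
+```python
+import torch
+import intel_extension_for_pytorch as ipex
+
+# Toy module standing in for e.g. the UNet; ipex.optimize returns an optimized copy
+model = torch.nn.Linear(16, 16).eval()
+optimized_model = ipex.optimize(model, dtype=torch.bfloat16)
+
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+    out = optimized_model(torch.randn(1, 16))
+```
+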
+Then you can use the IPEX pipeline in a similar way to the default Stable Diffusion pipeline.
+
+```python
+# For Float32
+image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] # value of image height/width should be consistent with 'prepare_for_ipex()'
+# For BFloat16
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] # value of image height/width should be consistent with 'prepare_for_ipex()'
+```
+
+The following code compares the performance of the original Stable Diffusion pipeline with the IPEX-optimized pipeline.
+
+```python
+import torch
+import intel_extension_for_pytorch as ipex
+from diffusers import DiffusionPipeline, StableDiffusionPipeline
+import time
+
+prompt = "sailing ship in storm by Rembrandt"
+model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+# Helper function for time evaluation
+def elapsed_time(pipeline, nb_pass=3, num_inference_steps=20):
+ # warmup
+ for _ in range(2):
+ images = pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images
+ # time evaluation
+ start = time.time()
+ for _ in range(nb_pass):
+ pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512)
+ end = time.time()
+ return (end - start) / nb_pass
+
+############## bf16 inference performance ###############
+
+# 1. IPEX Pipeline initialization
+pipe = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex")
+pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512)
+
+# 2. Original Pipeline initialization
+pipe2 = StableDiffusionPipeline.from_pretrained(model_id)
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ latency = elapsed_time(pipe)
+ print("Latency of StableDiffusionIPEXPipeline--bf16", latency)
+ latency = elapsed_time(pipe2)
+ print("Latency of StableDiffusionPipeline--bf16", latency)
+
+############## fp32 inference performance ###############
+
+# 1. IPEX Pipeline initialization
+pipe3 = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex")
+pipe3.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512)
+
+# 2. Original Pipeline initialization
+pipe4 = StableDiffusionPipeline.from_pretrained(model_id)
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+latency = elapsed_time(pipe3)
+print("Latency of StableDiffusionIPEXPipeline--fp32", latency)
+latency = elapsed_time(pipe4)
+print("Latency of StableDiffusionPipeline--fp32", latency)
+```
+
+### Stable Diffusion XL on IPEX
+
+This diffusion pipeline aims to accelerate the inference of Stable-Diffusion XL on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).
+
+To use this pipeline, you need to:
+
+1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)
+
+**Note:** For each PyTorch release, there is a corresponding release of IPEX. Here is the mapping relationship. It is recommended to install PyTorch/IPEX 2.0 to get the best performance.
+
+|PyTorch Version|IPEX Version|
+|--|--|
+|[v2.0.\*](https://github.com/pytorch/pytorch/tree/v2.0.1 "v2.0.1")|[v2.0.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|
+|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|
+
+You can simply use pip to install IPEX with the latest version.
+
+```sh
+python -m pip install intel_extension_for_pytorch
+```
+
+**Note:** To install a specific version, run with the following command:
+
+```sh
+python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
+```
+
+2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX acceleration. Supported inference datatypes are Float32 and BFloat16.
+
+**Note:** The values of `height` and `width` used during preparation with `prepare_for_ipex()` should be the same when running inference with the prepared pipeline.
+
+```python
+import torch
+from pipeline_stable_diffusion_xl_ipex import StableDiffusionXLPipelineIpex
+
+prompt = "sailing ship in storm by Rembrandt"  # example prompt used for the preparation step
+
+pipe = StableDiffusionXLPipelineIpex.from_pretrained("stabilityai/sdxl-turbo", low_cpu_mem_usage=True, use_safetensors=True)
+# value of image height/width should be consistent with the pipeline inference
+# For Float32
+pipe.prepare_for_ipex(torch.float32, prompt, height=512, width=512)
+# For BFloat16
+pipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)
+```
+
+Then you can use the IPEX pipeline in a similar way to the default Stable Diffusion XL pipeline.
+
+```python
+# value of image height/width should be consistent with 'prepare_for_ipex()'
+# For Float32
+image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]
+# For BFloat16
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]
+```
+
+The following code compares the performance of the original Stable Diffusion XL pipeline with the IPEX-optimized pipeline.
+By using this optimized pipeline, we can get about a 1.4-2x performance boost with BFloat16 on fourth-generation Intel Xeon CPUs,
+code-named Sapphire Rapids.
+
+```python
+import torch
+from diffusers import StableDiffusionXLPipeline
+from pipeline_stable_diffusion_xl_ipex import StableDiffusionXLPipelineIpex
+import time
+
+prompt = "sailing ship in storm by Rembrandt"
+model_id = "stabilityai/sdxl-turbo"
+steps = 4
+
+# Helper function for time evaluation
+def elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):
+ # warmup
+ for _ in range(2):
+ images = pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=0.0).images
+ # time evaluation
+ start = time.time()
+ for _ in range(nb_pass):
+ pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=0.0)
+ end = time.time()
+ return (end - start) / nb_pass
+
+############## bf16 inference performance ###############
+
+# 1. IPEX Pipeline initialization
+pipe = StableDiffusionXLPipelineIpex.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)
+pipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)
+
+# 2. Original Pipeline initialization
+pipe2 = StableDiffusionXLPipeline.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ latency = elapsed_time(pipe, num_inference_steps=steps)
+ print("Latency of StableDiffusionXLPipelineIpex--bf16", latency, "s for total", steps, "steps")
+ latency = elapsed_time(pipe2, num_inference_steps=steps)
+ print("Latency of StableDiffusionXLPipeline--bf16", latency, "s for total", steps, "steps")
+
+############## fp32 inference performance ###############
+
+# 1. IPEX Pipeline initialization
+pipe3 = StableDiffusionXLPipelineIpex.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)
+pipe3.prepare_for_ipex(torch.float32, prompt, height=512, width=512)
+
+# 2. Original Pipeline initialization
+pipe4 = StableDiffusionXLPipeline.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+latency = elapsed_time(pipe3, num_inference_steps=steps)
+print("Latency of StableDiffusionXLPipelineIpex--fp32", latency, "s for total", steps, "steps")
+latency = elapsed_time(pipe4, num_inference_steps=steps)
+print("Latency of StableDiffusionXLPipeline--fp32", latency, "s for total", steps, "steps")
+```
+
+### CLIP Guided Images Mixing With Stable Diffusion
+
+
+
+The CLIP guided Stable Diffusion images mixing pipeline allows you to combine two images using standard diffusion models.
+The approach can use an (optional) CoCa model to caption the content image, so you do not have to write an image description yourself.
+[More code examples](https://github.com/TheDenk/images_mixing)
+
+### Stable Diffusion XL Long Weighted Prompt Pipeline
+
+This SDXL pipeline supports prompts and negative prompts of unlimited length and is compatible with the A1111 prompt-weighting style.
+
+You can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is sample code showing how to use this pipeline.
+
+```python
+from diffusers import DiffusionPipeline
+from diffusers.utils import load_image
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+    use_safetensors=True,
+    variant="fp16",
+    custom_pipeline="lpw_stable_diffusion_xl",
+)
+
+prompt = "photo of a cute (white) cat running on the grass" * 20
+prompt2 = "chasing (birds:1.5)" * 20
+prompt = f"{prompt},{prompt2}"
+neg_prompt = "blur, low quality, carton, animate"
+
+pipe.to("cuda")
+
+# text2img
+t2i_images = pipe(
+ prompt=prompt,
+ negative_prompt=neg_prompt,
+).images # alternatively, you can call the .text2img() function
+
+# img2img
+input_image = load_image("/path/to/local/image.png") # or URL to your input image
+i2i_images = pipe.img2img(
+ prompt=prompt,
+ negative_prompt=neg_prompt,
+ image=input_image,
+ strength=0.8, # higher strength will result in more variation compared to original image
+).images
+
+# inpaint
+input_mask = load_image("/path/to/local/mask.png") # or URL to your input inpainting mask
+inpaint_images = pipe.inpaint(
+ prompt="photo of a cute (black) cat running on the grass" * 20,
+ negative_prompt=neg_prompt,
+ image=input_image,
+ mask=input_mask,
+ strength=0.6, # higher strength will result in more variation compared to original image
+).images
+
+pipe.to("cpu")
+torch.cuda.empty_cache()
+
+from IPython.display import display # assuming you are using this code in a notebook
+display(t2i_images[0])
+display(i2i_images[0])
+display(inpaint_images[0])
+```
+
+In the above code, `prompt2` is appended to `prompt`, making the combined prompt longer than 77 tokens; "birds" still show up in the result.
+
+
+For more results, check out [PR #6114](https://github.com/huggingface/diffusers/pull/6114).
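+
+For context, pipelines like this typically get past CLIP's 77-token limit by splitting the tokenized prompt into chunks, encoding each chunk separately, and concatenating the resulting embeddings along the sequence dimension. A rough sketch of that idea, using the SD 1.x CLIP text encoder for illustration (SDXL actually uses two text encoders, but the chunking idea is the same); this is a hypothetical helper, not this pipeline's exact code:
+
+```python
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+
+tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+def encode_long_prompt(prompt: str, chunk_size: int = 75) -> torch.Tensor:
+    # Tokenize without truncation, then encode in windows of 75 tokens plus BOS/EOS (77 total)
+    ids = tokenizer(prompt, truncation=False, add_special_tokens=False).input_ids
+    chunks = [ids[i:i + chunk_size] for i in range(0, len(ids), chunk_size)]
+    embeds = []
+    for chunk in chunks:
+        chunk_ids = torch.tensor([[tokenizer.bos_token_id] + chunk + [tokenizer.eos_token_id]])
+        embeds.append(text_encoder(chunk_ids)[0])
+    return torch.cat(embeds, dim=1)  # (1, total_tokens, hidden_dim)
+
+print(encode_long_prompt("photo of a cute white cat " * 30).shape)
+```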
+
+### Example Images Mixing (with CoCa)
+
+```python
+import requests
+from io import BytesIO
+
+import PIL
+import torch
+import open_clip
+from open_clip import SimpleTokenizer
+from diffusers import DiffusionPipeline
+from transformers import CLIPImageProcessor, CLIPModel
+
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+# Loading additional models
+feature_extractor = CLIPImageProcessor.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+)
+clip_model = CLIPModel.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
+)
+coca_model = open_clip.create_model('coca_ViT-L-14', pretrained='laion2B-s13B-b90k').to('cuda')
+coca_model.dtype = torch.float16
+coca_transform = open_clip.image_transform(
+ coca_model.visual.image_size,
+ is_train=False,
+ mean=getattr(coca_model.visual, 'image_mean', None),
+ std=getattr(coca_model.visual, 'image_std', None),
+)
+coca_tokenizer = SimpleTokenizer()
+
+# Pipeline creating
+mixing_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_images_mixing_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ coca_model=coca_model,
+ coca_tokenizer=coca_tokenizer,
+ coca_transform=coca_transform,
+ torch_dtype=torch.float16,
+)
+mixing_pipeline.enable_attention_slicing()
+mixing_pipeline = mixing_pipeline.to("cuda")
+
+# Pipeline running
+generator = torch.Generator(device="cuda").manual_seed(17)
+
+content_image = download_image("https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir.jpg")
+style_image = download_image("https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/gigachad.jpg")
+
+pipe_images = mixing_pipeline(
+ num_inference_steps=50,
+ content_image=content_image,
+ style_image=style_image,
+ noise_strength=0.65,
+ slerp_latent_style_strength=0.9,
+ slerp_prompt_style_strength=0.1,
+ slerp_clip_image_style_strength=0.1,
+ guidance_scale=9.0,
+ batch_size=1,
+ clip_guidance_scale=100,
+ generator=generator,
+).images
+```
+
+
+
+### Stable Diffusion Mixture Tiling
+
+This pipeline generates tiled images using the Mixture of Diffusers approach. Refer to the [Mixture of Diffusers](https://arxiv.org/abs/2302.02412) paper for more details.
+
+```python
+from diffusers import LMSDiscreteScheduler, DiffusionPipeline
+
+# Create scheduler and model (similar to StableDiffusionPipeline)
+scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, custom_pipeline="mixture_tiling")
+pipeline.to("cuda")
+
+# Mixture of Diffusers generation
+image = pipeline(
+ prompt=[[
+ "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
+ ]],
+ tile_height=640,
+ tile_width=640,
+ tile_row_overlap=0,
+ tile_col_overlap=256,
+ guidance_scale=8,
+ seed=7178915308,
+ num_inference_steps=50,
+)["images"][0]
+```
+
+
+
+### TensorRT Inpainting Stable Diffusion Pipeline
+
+The TensorRT pipeline can be used to accelerate inference with the Stable Diffusion inpainting model.
+
+NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
+
+```python
+import requests
+from io import BytesIO
+from PIL import Image
+import torch
+from diffusers import PNDMScheduler
+from diffusers.pipelines import DiffusionPipeline
+
+# Use the PNDMScheduler scheduler here instead
+scheduler = PNDMScheduler.from_pretrained("stabilityai/stable-diffusion-2-inpainting", subfolder="scheduler")
+
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
+ custom_pipeline="stable_diffusion_tensorrt_inpaint",
+ variant='fp16',
+ torch_dtype=torch.float16,
+ scheduler=scheduler,
+ )
+
+# re-use cached folder to save ONNX models and TensorRT Engines
+pipe.set_cached_folder("stabilityai/stable-diffusion-2-inpainting", variant='fp16',)
+
+pipe = pipe.to("cuda")
+
+url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+response = requests.get(url)
+input_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+response = requests.get(mask_url)
+mask_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+prompt = "a mecha robot sitting on a bench"
+image = pipe(prompt, image=input_image, mask_image=mask_image, strength=0.75,).images[0]
+image.save('tensorrt_inpaint_mecha_robot.png')
+```
+
+### Stable Diffusion Mixture Canvas
+
+This pipeline uses the Mixture of Diffusers approach. Refer to the [Mixture of Diffusers](https://arxiv.org/abs/2302.02412) paper for more details.
+
+```python
+from PIL import Image
+from diffusers import LMSDiscreteScheduler, DiffusionPipeline
+from examples.community.mixture_canvas import Image2ImageRegion, Text2ImageRegion, preprocess_image
+
+
+# Load and preprocess guide image
+iic_image = preprocess_image(Image.open("input_image.png").convert("RGB"))
+
+# Create scheduler and model (similar to StableDiffusionPipeline)
+scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, custom_pipeline="mixture_canvas")
+pipeline.to("cuda")
+
+# Mixture of Diffusers generation
+output = pipeline(
+ canvas_height=800,
+ canvas_width=352,
+ regions=[
+ Text2ImageRegion(0, 800, 0, 352, guidance_scale=8,
+ prompt=f"best quality, masterpiece, WLOP, sakimichan, art contest winner on pixiv, 8K, intricate details, wet effects, rain drops, ethereal, mysterious, futuristic, UHD, HDR, cinematic lighting, in a beautiful forest, rainy day, award winning, trending on artstation, beautiful confident cheerful young woman, wearing a futuristic sleeveless dress, ultra beautiful detailed eyes, hyper-detailed face, complex, perfect, model, textured, chiaroscuro, professional make-up, realistic, figure in frame, "),
+        Image2ImageRegion(800-352, 800, 0, 352, reference_image=iic_image, strength=1.0),
+ ],
+ num_inference_steps=100,
+ seed=5525475061,
+)["images"][0]
+```
+
+
+
+
+### IADB pipeline
+
+This pipeline is the implementation of the [α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) paper.
+It is a simple and minimalist diffusion model.
+
+The following code shows how to use the IADB pipeline to generate images using a pretrained celebahq-256 model.
+
+```python
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline_iadb = DiffusionPipeline.from_pretrained("thomasc4/iadb-celebahq-256", custom_pipeline='iadb')
+
+pipeline_iadb = pipeline_iadb.to('cuda')
+
+output = pipeline_iadb(batch_size=4, num_inference_steps=128)
+for i in range(len(output[0])):
+ plt.imshow(output[0][i])
+ plt.show()
+```
+
+Sampling with the IADB formulation is easy, and can be done in a few lines (the pipeline already implements it):
+
+```python
+def sample_iadb(model, x0, nb_step):
+ x_alpha = x0
+ for t in range(nb_step):
+ alpha = (t/nb_step)
+ alpha_next =((t+1)/nb_step)
+
+ d = model(x_alpha, torch.tensor(alpha, device=x_alpha.device))['sample']
+ x_alpha = x_alpha + (alpha_next-alpha)*d
+
+ return x_alpha
+```
+
+The training loop is also straightforward:
+
+```python
+import torch
+
+# Training loop (sketch: `D`, `sample_noise`, `sample_dataset`, `batch_size`, and `optimizer`
+# are placeholders for your own denoiser, sampling functions, batch size, and optimizer)
+while True:
+    x0 = sample_noise()    # Gaussian noise batch of shape [B, C, H, W]
+    x1 = sample_dataset()  # data batch of the same shape
+
+    alpha = torch.rand(batch_size)  # one blending coefficient per sample
+
+    # Blend (reshape alpha so it broadcasts over [B, C, H, W])
+    a = alpha.view(-1, 1, 1, 1)
+    x_alpha = (1 - a) * x0 + a * x1
+
+    # Loss
+    loss = torch.sum((D(x_alpha, alpha) - (x1 - x0)) ** 2)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+```
+
+### Zero1to3 pipeline
+
+This pipeline is the implementation of the [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) paper.
+The original PyTorch Lightning implementation can be found in [this repo](https://github.com/cvlab-columbia/zero123), and a diffusers port in [this repo](https://github.com/kxhit/zero123-hf).
+
+The following code shows how to use the Zero1to3 pipeline to generate novel view synthesis images using a pretrained stable diffusion model.
+
+```python
+import os
+import torch
+from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
+from diffusers.utils import load_image
+
+model_id = "kxic/zero123-165000" # zero123-105000, zero123-165000, zero123-xl
+
+pipe = Zero1to3StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+
+pipe.enable_xformers_memory_efficient_attention()
+pipe.enable_vae_tiling()
+pipe.enable_attention_slicing()
+pipe = pipe.to("cuda")
+
+num_images_per_prompt = 4
+
+# test inference pipeline
+# Each query pose is [polar angle (vertical rotation, degrees), azimuth angle (horizontal rotation, degrees), zoom (relative distance from center)]
+query_pose1 = [-75.0, 100.0, 0.0]
+query_pose2 = [-20.0, 125.0, 0.0]
+query_pose3 = [-55.0, 90.0, 0.0]
+
+# load image
+# H, W = (256, 256) # H, W = (512, 512) # zero123 training is 256,256
+
+# for batch input
+input_image1 = load_image("./demo/4_blackarm.png") # load_image("https://cvlab-zero123-live.hf.space/file=/home/user/app/configs/4_blackarm.png")
+input_image2 = load_image("./demo/8_motor.png") # load_image("https://cvlab-zero123-live.hf.space/file=/home/user/app/configs/8_motor.png")
+input_image3 = load_image("./demo/7_london.png") # load_image("https://cvlab-zero123-live.hf.space/file=/home/user/app/configs/7_london.png")
+input_images = [input_image1, input_image2, input_image3]
+query_poses = [query_pose1, query_pose2, query_pose3]
+
+# # for single input
+# H, W = (256, 256)
+# input_images = [input_image2.resize((H, W), PIL.Image.NEAREST)]
+# query_poses = [query_pose2]
+
+
+# better do preprocessing
+from gradio_new import preprocess_image, create_carvekit_interface
+import numpy as np
+import PIL.Image as Image
+
+pre_images = []
+models = dict()
+print('Instantiating Carvekit HiInterface...')
+models['carvekit'] = create_carvekit_interface()
+if not isinstance(input_images, list):
+ input_images = [input_images]
+for raw_im in input_images:
+ input_im = preprocess_image(models, raw_im, True)
+ H, W = input_im.shape[:2]
+ pre_images.append(Image.fromarray((input_im * 255.0).astype(np.uint8)))
+input_images = pre_images
+
+# infer pipeline, in original zero123 num_inference_steps=76
+images = pipe(input_imgs=input_images, prompt_imgs=input_images, poses=query_poses, height=H, width=W,
+ guidance_scale=3.0, num_images_per_prompt=num_images_per_prompt, num_inference_steps=50).images
+
+# save imgs
+log_dir = "logs"
+os.makedirs(log_dir, exist_ok=True)
+bs = len(input_images)
+i = 0
+for obj in range(bs):
+ for idx in range(num_images_per_prompt):
+ images[i].save(os.path.join(log_dir,f"obj{obj}_{idx}.jpg"))
+ i += 1
+```
+
+### Stable Diffusion XL Reference
+
+This pipeline applies the Reference-only control technique to Stable Diffusion XL. Refer to the [stable_diffusion_reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference) section for more details.
+
+```py
+import torch
+from PIL import Image
+from diffusers.utils import load_image
+from diffusers import DiffusionPipeline
+from diffusers.schedulers import UniPCMultistepScheduler
+
+input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    custom_pipeline="stable_diffusion_xl_reference",
+    torch_dtype=torch.float16,
+    use_safetensors=True,
+    variant="fp16").to('cuda:0')
+
+# If you have the community pipeline file locally, you can instead import
+# StableDiffusionXLReferencePipeline and instantiate it directly:
+# pipe = StableDiffusionXLReferencePipeline.from_pretrained(
+#     "stabilityai/stable-diffusion-xl-base-1.0",
+#     torch_dtype=torch.float16,
+#     use_safetensors=True,
+#     variant="fp16").to('cuda:0')
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+```
+
+Example results (reference and output images omitted here):
+
+- `prompt: 1 girl`, `reference_attn=True, reference_adain=True, num_inference_steps=20`
+- `prompt: A dog`, `reference_attn=True, reference_adain=False, num_inference_steps=20`
+- `prompt: An astronaut riding a lion`, `reference_attn=True, reference_adain=True, num_inference_steps=20`
+
+
+### Stable diffusion fabric pipeline
+
+FABRIC is an approach applicable to a wide range of popular diffusion models. It exploits
+the self-attention layers present in the most widely used architectures to condition
+the diffusion process on a set of feedback images.
+
+```python
+import requests
+import torch
+from PIL import Image
+from io import BytesIO
+
+from diffusers import DiffusionPipeline
+
+# load the pipeline
+# make sure you're logged in with `huggingface-cli login`
+model_id_or_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+# can also be used with dreamlike-art/dreamlike-photoreal-2.0
+pipe = DiffusionPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16, custom_pipeline="pipeline_fabric").to("cuda")
+
+# let's specify a prompt
+prompt = "An astronaut riding an elephant"
+negative_prompt = "lowres, cropped"
+
+# call the pipeline
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=20,
+ generator=torch.manual_seed(12)
+).images[0]
+
+image.save("horse_to_elephant.jpg")
+
+# let's try another example with feedback
+url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+prompt = "photo, A blue colored car, fish eye"
+liked = [init_image]
+## same goes with disliked
+
+# call the pipeline
+torch.manual_seed(0)
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ liked=liked,
+ num_inference_steps=20,
+).images[0]
+
+image.save("black_to_blue.png")
+```
+
+*With enough feedback images, you can generate very similar high-quality images.*
+
+The original codebase can be found at [sd-fabric/fabric](https://github.com/sd-fabric/fabric), and available checkpoints are [dreamlike-art/dreamlike-photoreal-2.0](https://huggingface.co/dreamlike-art/dreamlike-photoreal-2.0), [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), and [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) (may give unexpected results).
+
+Let's have a look at the images (_512×512_):
+
+| Without Feedback | With Feedback (1st image) |
+|---------------------|---------------------|
+|  |  |
+
+### Masked Im2Im Stable Diffusion Pipeline
+
+This pipeline reimplements the sketch inpainting feature from A1111 for non-inpainting models. The following code reads two images: the original and a copy with the mask painted over it. It computes the mask as the difference between the two images and inpaints the area defined by that mask.
+
+```python
+import numpy
+import PIL
+import torch
+from diffusers import EulerAncestralDiscreteScheduler
+# the community pipeline class (examples/community/masked_stable_diffusion_img2img.py)
+from examples.community.masked_stable_diffusion_img2img import MaskedStableDiffusionImg2ImgPipeline
+
+# read the original image
+img = PIL.Image.open("./mech.png")
+# read image with mask painted over
+img_paint = PIL.Image.open("./mech_painted.png")
+neq = numpy.any(numpy.array(img) != numpy.array(img_paint), axis=-1)
+mask = neq / neq.max()
+
+pipeline = MaskedStableDiffusionImg2ImgPipeline.from_pretrained("frankjoshua/icbinpICantBelieveIts_v8")
+
+# works best with EulerAncestralDiscreteScheduler
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+generator = torch.Generator(device="cpu").manual_seed(4)
+
+prompt = "a man wearing a mask"
+result = pipeline(prompt=prompt, image=img_paint, mask=mask, strength=0.75,
+ generator=generator)
+result.images[0].save("result.png")
+```
+
+Original image (`mech.png`), image with the mask painted over (`mech_painted.png`), and the result (images omitted here).
+
+### Masked Im2Im Stable Diffusion Pipeline XL
+
+This pipeline implements the sketch inpainting feature from A1111 for non-inpainting models. The following code reads two images: the original and a copy with the mask painted over it. It computes the mask as the difference between the two images and inpaints the area defined by that mask. By default, the latent code is initialized from the image with the mask, so the color of the mask affects the result.
+
+```python
+import PIL
+import torch
+# MaskedStableDiffusionXLImg2ImgPipeline is the corresponding community pipeline class
+
+img = PIL.Image.open("./mech.png")
+# read image with mask painted over
+img_paint = PIL.Image.open("./mech_painted.png")
+
+pipeline = MaskedStableDiffusionXLImg2ImgPipeline.from_pretrained("frankjoshua/juggernautXL_v8Rundiffusion", torch_dtype=torch.float16)
+
+pipeline.to('cuda')
+pipeline.enable_xformers_memory_efficient_attention()
+
+prompt = "a mech warrior wearing a mask"
+seed = 8348273636437
+for i in range(10):
+ generator = torch.Generator(device="cuda").manual_seed(seed + i)
+ print(seed + i)
+ result = pipeline(prompt=prompt, blur=48, image=img_paint, original_image=img, strength=0.9,
+ generator=generator, num_inference_steps=60, num_images_per_prompt=1)
+ im = result.images[0]
+ im.save(f"result{i}.png")
+```
+
+Original image (`mech.png`), image with the mask painted over (`mech_painted.png`), and the results (images omitted here).
+
+### Prompt2Prompt Pipeline
+
+Prompt2Prompt allows the following edits:
+
+- ReplaceEdit (change words in prompt)
+- ReplaceEdit with local blend (change words in prompt, keep image part unrelated to changes constant)
+- RefineEdit (add words to prompt)
+- RefineEdit with local blend (add words to prompt, keep image part unrelated to changes constant)
+- ReweightEdit (modulate importance of words)
+
+Here's a full example for `ReplaceEdit`:
+
+```python
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")
+
+prompts = ["A turtle playing with a ball",
+ "A monkey playing with a ball"]
+
+cross_attention_kwargs = {
+ "edit_type": "replace",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4
+}
+
+outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs)
+```
+
+And abbreviated examples for the other edits:
+
+`ReplaceEdit with local blend`
+
+```python
+prompts = ["A turtle playing with a ball",
+ "A monkey playing with a ball"]
+
+cross_attention_kwargs = {
+ "edit_type": "replace",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4,
+ "local_blend_words": ["turtle", "monkey"]
+}
+```
+
+`RefineEdit`
+
+```python
+prompts = ["A turtle",
+ "A turtle in a forest"]
+
+cross_attention_kwargs = {
+ "edit_type": "refine",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4,
+}
+```
+
+`RefineEdit with local blend`
+
+```python
+prompts = ["A turtle",
+ "A turtle in a forest"]
+
+cross_attention_kwargs = {
+ "edit_type": "refine",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4,
+ "local_blend_words": ["in", "a" , "forest"]
+}
+```
+
+`ReweightEdit`
+
+```python
+prompts = ["A smiling turtle"] * 2
+
+cross_attention_kwargs = {
+ "edit_type": "reweight",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4,
+ "equalizer_words": ["smiling"],
+ "equalizer_strengths": [5]
+}
+```
+
+Side note: See [this GitHub gist](https://gist.github.com/UmerHA/b65bb5fb9626c9c73f3ade2869e36164) if you want to visualize the attention maps.
+
+### Latent Consistency Pipeline
+
+Latent Consistency Models was proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) by _Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, Hang Zhao_ from Tsinghua University.
+
+The abstract of the paper reads as follows:
+
+*Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (rombach et al). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference. Project Page: [this https URL](https://latent-consistency-models.github.io/)*
+
+The model can be used with `diffusers` as follows:
+
+- 1. Load the model from the community pipeline:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_txt2img", custom_revision="main")
+
+# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+pipe.to(torch_device="cuda", torch_dtype=torch.float32)
+```
+
+- 2. Run inference with as little as 4 steps:
+
+```py
+prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+
+# Can be set to 1~50 steps. LCM supports fast inference even <= 4 steps. Recommend: 1~8 steps.
+num_inference_steps = 4
+
+images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images
+```
+
+For any questions or feedback, feel free to reach out to [Simian Luo](https://github.com/luosiallen).
+
+You can also try this pipeline directly in the [🚀 official spaces](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).
+
+### Latent Consistency Img2img Pipeline
+
+This pipeline extends the Latent Consistency Pipeline to allow it to take an input image.
+
+- 1. Load the model from the community pipeline:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_img2img")
+
+# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+pipe.to(torch_device="cuda", torch_dtype=torch.float32)
+```
+
+- 2. Run inference with as little as 4 steps:
+
+```py
+prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+
+# load the input image (requires `from PIL import Image`)
+input_image = Image.open("myimg.png")
+
+strength = 0.5  # strength=0 leaves the image unchanged, strength=1 completely overwrites it
+
+# Can be set to 1~50 steps. LCM supports fast inference even <= 4 steps. Recommend: 1~8 steps.
+num_inference_steps = 4
+
+images = pipe(prompt=prompt, image=input_image, strength=strength, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images
+```
+
+### Latent Consistency Interpolation Pipeline
+
+This pipeline extends the Latent Consistency Pipeline to allow for interpolation of the latent space between multiple prompts. It is similar to the [Stable Diffusion Interpolate](https://github.com/huggingface/diffusers/blob/main/examples/community/interpolate_stable_diffusion.py) and [unCLIP Interpolate](https://github.com/huggingface/diffusers/blob/main/examples/community/unclip_text_interpolation.py) community pipelines.
+
+```py
+import torch
+import numpy as np
+
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_interpolate")
+
+# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+pipe.to(torch_device="cuda", torch_dtype=torch.float32)
+
+prompts = [
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, Margot Robbie, 8k",
+ "Self-portrait oil painting, an extremely strong man, body builder, Huge Jackman, 8k",
+ "An astronaut floating in space, renaissance art, realistic, high quality, 8k",
+ "Oil painting of a cat, cute, dream-like",
+ "Hugging face emoji, cute, realistic"
+]
+num_inference_steps = 4
+num_interpolation_steps = 60
+seed = 1337
+
+torch.manual_seed(seed)
+np.random.seed(seed)
+
+images = pipe(
+ prompt=prompts,
+ height=512,
+ width=512,
+ num_inference_steps=num_inference_steps,
+ num_interpolation_steps=num_interpolation_steps,
+ guidance_scale=8.0,
+ embedding_interpolation_type="lerp",
+ latent_interpolation_type="slerp",
+ process_batch_size=4, # Make it higher or lower based on your GPU memory
+    generator=torch.Generator().manual_seed(seed),
+)
+
+assert len(images) == (len(prompts) - 1) * num_interpolation_steps
+```
+
+### StableDiffusionUpscaleLDM3D Pipeline
+
+[LDM3D-VR](https://arxiv.org/pdf/2311.03226.pdf) is an extended version of LDM3D.
+
+The abstract from the paper is:
+*Latent diffusion models have proven to be state-of-the-art in the creation and manipulation of visual outputs. However, as far as we know, the generation of depth maps jointly with RGB is still limited. We introduce LDM3D-VR, a suite of diffusion models targeting virtual reality development that includes LDM3D-pano and LDM3D-SR. These models enable the generation of panoramic RGBD based on textual prompts and the upscaling of low-resolution inputs to high-resolution RGBD, respectively. Our models are fine-tuned from existing pretrained models on datasets containing panoramic/high-resolution RGB images, depth maps and captions. Both models are evaluated in comparison to existing related methods*
+
+Two checkpoints are available for use:
+
+- [ldm3d-pano](https://huggingface.co/Intel/ldm3d-pano). This checkpoint enables the generation of panoramic images and requires the StableDiffusionLDM3DPipeline pipeline to be used.
+- [ldm3d-sr](https://huggingface.co/Intel/ldm3d-sr). This checkpoint enables the upscaling of RGB and depth images. It can be used in cascade after the original LDM3D pipeline, via the StableDiffusionUpscaleLDM3DPipeline.
+
+```py
+from PIL import Image
+import os
+import torch
+from diffusers import StableDiffusionLDM3DPipeline, DiffusionPipeline
+
+# Generate a rgb/depth output from LDM3D
+
+pipe_ldm3d = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c")
+pipe_ldm3d.to("cuda")
+
+prompt = "A picture of some lemons on a table"
+output = pipe_ldm3d(prompt)
+rgb_image, depth_image = output.rgb, output.depth
+rgb_image[0].save("lemons_ldm3d_rgb.jpg")
+depth_image[0].save("lemons_ldm3d_depth.png")
+
+# Upscale the previous output to a resolution of (1024, 1024)
+
+pipe_ldm3d_upscale = DiffusionPipeline.from_pretrained("Intel/ldm3d-sr", custom_pipeline="pipeline_stable_diffusion_upscale_ldm3d")
+
+pipe_ldm3d_upscale.to("cuda")
+
+low_res_img = Image.open("lemons_ldm3d_rgb.jpg").convert("RGB")
+low_res_depth = Image.open("lemons_ldm3d_depth.png").convert("L")
+outputs = pipe_ldm3d_upscale(prompt="high quality high resolution uhd 4k image", rgb=low_res_img, depth=low_res_depth, num_inference_steps=50, target_res=[1024, 1024])
+
+upscaled_rgb, upscaled_depth = outputs.rgb[0], outputs.depth[0]
+upscaled_rgb.save("upscaled_lemons_rgb.png")
+upscaled_depth.save("upscaled_lemons_depth.png")
+```
+
+### ControlNet + T2I Adapter Pipeline
+
+This pipeline combines both ControlNet and T2IAdapter into a single pipeline, where the forward pass is executed once.
+It receives `control_image` and `adapter_image`, as well as `controlnet_conditioning_scale` and `adapter_conditioning_scale`, for the ControlNet and Adapter modules, respectively. Whenever `adapter_conditioning_scale=0` or `controlnet_conditioning_scale=0`, it will act as a full ControlNet module or as a full T2IAdapter module, respectively.
+
+```py
+import cv2
+import numpy as np
+import torch
+from controlnet_aux.midas import MidasDetector
+from PIL import Image
+
+from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.utils import load_image
+from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter import (
+ StableDiffusionXLControlNetAdapterPipeline,
+)
+
+controlnet_depth = ControlNetModel.from_pretrained(
+ "diffusers/controlnet-depth-sdxl-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True
+)
+adapter_depth = T2IAdapter.from_pretrained(
+ "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+
+pipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ controlnet=controlnet_depth,
+ adapter=adapter_depth,
+ vae=vae,
+ variant="fp16",
+ use_safetensors=True,
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+pipe.enable_xformers_memory_efficient_attention()
+# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
+midas_depth = MidasDetector.from_pretrained(
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+).to("cuda")
+
+prompt = "a tiger sitting on a park bench"
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+
+image = load_image(img_url).resize((1024, 1024))
+
+depth_image = midas_depth(
+ image, detect_resolution=512, image_resolution=1024
+)
+
+strength = 0.5
+
+images = pipe(
+ prompt,
+ control_image=depth_image,
+ adapter_image=depth_image,
+ num_inference_steps=30,
+ controlnet_conditioning_scale=strength,
+ adapter_conditioning_scale=strength,
+).images
+images[0].save("controlnet_and_adapter.png")
+```
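+
+Since setting `adapter_conditioning_scale=0` makes the combined pipeline behave like a plain ControlNet pipeline (and `controlnet_conditioning_scale=0` makes it behave like a plain T2IAdapter pipeline), a quick way to compare the two contributions is to zero one of them out. This is a minimal sketch reusing the `pipe`, `prompt`, and `depth_image` objects defined above:
+
+```py
+# ControlNet-only pass: the T2I-Adapter branch contributes nothing
+controlnet_only = pipe(
+    prompt,
+    control_image=depth_image,
+    adapter_image=depth_image,
+    num_inference_steps=30,
+    controlnet_conditioning_scale=0.5,
+    adapter_conditioning_scale=0.0,
+).images[0]
+
+# T2I-Adapter-only pass: the ControlNet branch contributes nothing
+adapter_only = pipe(
+    prompt,
+    control_image=depth_image,
+    adapter_image=depth_image,
+    num_inference_steps=30,
+    controlnet_conditioning_scale=0.0,
+    adapter_conditioning_scale=0.5,
+).images[0]
+```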
+
+### ControlNet + T2I Adapter + Inpainting Pipeline
+
+```py
+import cv2
+import numpy as np
+import torch
+from controlnet_aux.midas import MidasDetector
+from PIL import Image
+
+from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.utils import load_image
+from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter_inpaint import (
+ StableDiffusionXLControlNetAdapterInpaintPipeline,
+)
+
+controlnet_depth = ControlNetModel.from_pretrained(
+ "diffusers/controlnet-depth-sdxl-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True
+)
+adapter_depth = T2IAdapter.from_pretrained(
+ "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+
+pipe = StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained(
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+ controlnet=controlnet_depth,
+ adapter=adapter_depth,
+ vae=vae,
+ variant="fp16",
+ use_safetensors=True,
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+pipe.enable_xformers_memory_efficient_attention()
+# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
+midas_depth = MidasDetector.from_pretrained(
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+).to("cuda")
+
+prompt = "a tiger sitting on a park bench"
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+image = load_image(img_url).resize((1024, 1024))
+mask_image = load_image(mask_url).resize((1024, 1024))
+
+depth_image = midas_depth(
+ image, detect_resolution=512, image_resolution=1024
+)
+
+strength = 0.4
+
+images = pipe(
+ prompt,
+ image=image,
+ mask_image=mask_image,
+ control_image=depth_image,
+ adapter_image=depth_image,
+ num_inference_steps=30,
+ controlnet_conditioning_scale=strength,
+ adapter_conditioning_scale=strength,
+ strength=0.7,
+).images
+images[0].save("controlnet_and_adapter_inpaint.png")
+```
+
+### Regional Prompting Pipeline
+
+This pipeline is a port of the [Regional Prompter extension](https://github.com/hako-mikan/sd-webui-regional-prompter) for [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to `diffusers`.
+This code implements a pipeline for the Stable Diffusion model, enabling the division of the canvas into multiple regions, with different prompts applicable to each region. Users can specify regions in two ways: using `Cols` and `Rows` modes for grid-like divisions, or the `Prompt` mode for regions calculated based on prompts.
+
+
+
+### Usage
+
+### Sample Code
+
+```py
+import time
+
+from examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline
+
+# `model_path`, `vae`, and `negative_prompt` below are placeholders: supply your own
+# checkpoint path, (optional) VAE, and negative prompt.
+
+pipe = RegionalPromptingStableDiffusionPipeline.from_single_file(model_path, vae=vae)
+
+rp_args = {
+ "mode":"rows",
+ "div": "1;1;1"
+}
+
+prompt = """
+green hair twintail BREAK
+red blouse BREAK
+blue skirt
+"""
+
+images = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ guidance_scale=7.5,
+ height=768,
+ width=512,
+ num_inference_steps=20,
+ num_images_per_prompt=1,
+ rp_args=rp_args
+ ).images
+
+timestamp = time.strftime(r"%Y%m%d%H%M%S")
+for i, image in enumerate(images, start=1):
+    image.save(f"img-{timestamp}-{i}.png")
+```
+
+### Cols, Rows mode
+
+In Cols and Rows modes, you can split the canvas vertically and horizontally and assign a prompt to each region. The split ratio is specified with `div`, e.g. `3;3;2` or `0.1;0.5`. Furthermore, as described later, you can subdivide the split columns and rows to specify more complex regions.
+
+In the example below, the canvas is divided into three parts, and a separate prompt is applied to each. The prompts are separated by `BREAK`, and each is applied to its respective region.
+
+
+```
+green hair twintail BREAK
+red blouse BREAK
+blue skirt
+```
+
+### 2-Dimensional division
+
+The prompt consists of instructions separated by the term `BREAK` and is assigned to different regions of a two-dimensional space. The image is initially split in the main splitting direction, which in this case is rows, due to the presence of a single semicolon `;`, dividing the space into an upper and a lower section. Additional sub-splitting is then applied, indicated by commas. The upper row is split into ratios of `2:1:1`, while the lower row is split into a ratio of `4:6`. Rows themselves are split in a `1:2` ratio. According to the reference image, the blue sky is designated as the first region, green hair as the second, the bookshelf as the third, and so on, in a sequence based on their position from the top left. The terrarium is placed on the desk in the fourth region, and the orange dress and sofa are in the fifth region, conforming to their respective splits.
+
+```py
+rp_args = {
+ "mode":"rows",
+ "div": "1,2,1,1;2,4,6"
+}
+
+prompt = """
+blue sky BREAK
+green hair BREAK
+book shelf BREAK
+terrarium on the desk BREAK
+orange dress and sofa
+"""
+```
+
+
+
+### Prompt Mode
+
+There are limitations to specifying regions in advance: pre-defined regions can be a hindrance when the composition is dynamic or the shapes are complex. In Prompt mode, the region is instead determined from the prompt after image generation has begun, which makes it possible to handle such compositions and complex regions.
+For further information, see [here](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/main/prompt_en.md).
+
+### Syntax
+
+```
+baseprompt target1 target2 BREAK
+effect1, target1 BREAK
+effect2 ,target2
+```
+
+First, write the base prompt. In the base prompt, include the words (target1, target2) for which you want to create masks, then add `BREAK`. Next, write the prompt corresponding to target1, followed by a comma and target1 itself. The order of the targets in the base prompt does not have to match the order of the `BREAK`-separated target prompts, so
+
+```
+target2 baseprompt target1 BREAK
+effect1, target1 BREAK
+effect2 ,target2
+```
+
+is also effective.
+
+### Sample
+
+In this example, masks are calculated for shirt, tie, skirt, and color prompts are specified only for those regions.
+
+```py
+rp_args = {
+ "mode": "prompt-ex",
+ "save_mask": True,
+ "th": "0.4,0.6,0.6",
+}
+
+prompt = """
+a girl in street with shirt, tie, skirt BREAK
+red, shirt BREAK
+green, tie BREAK
+blue , skirt
+"""
+```
+
+
+
+### Threshold
+
+The threshold determines the mask created from the prompt. Because the appropriate range varies widely depending on the target prompt, it can be set once per mask. If multiple regions are used, enter the values separated by commas. For example, an ambiguous target such as hair usually needs a small value, while a larger, clearly defined target such as a face can take a larger value. The values should be ordered to match the `BREAK`-separated targets.
+
+```
+a lady ,hair, face BREAK
+red, hair BREAK
+tanned ,face
+```
+
+`threshold: 0.4,0.6`
+
+If only one value is given for multiple regions, the same value is applied to all of them.
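+
+As a minimal sketch (mirroring the example above and reusing the `pipe` object from the sample code), the thresholds are passed through the `th` key of `rp_args`, one value per `BREAK`-separated target:
+
+```py
+rp_args = {
+    "mode": "prompt",
+    "th": "0.4,0.6",  # hair -> 0.4, face -> 0.6
+}
+
+prompt = """
+a lady, hair, face BREAK
+red, hair BREAK
+tanned, face
+"""
+
+images = pipe(prompt=prompt, rp_args=rp_args).images
+```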
+
+### Prompt and Prompt-EX
+
+The difference is that in Prompt mode, overlapping regions are added together, whereas in Prompt-EX mode, overlapping regions are overwritten sequentially. Since regions are processed in order, specifying a target with a large region first makes it easier for the effect of smaller regions to remain intact.
+
+### Accuracy
+
+For a 512x512 image, Attention mode reduces the region to about 8x8 pixels deep in the U-Net, so small regions can get mixed up; Latent mode computes the region at 64x64, so the region is more exact.
+
+```
+girl hair twintail frills,ribbons, dress, face BREAK
+girl, ,face
+```
+
+### Mask
+
+When an image is generated, the generated mask is displayed. It is generated at the same size as the image, but is actually used at a much smaller size.
+
+### Use common prompt
+
+You can prepend a common prompt to all regional prompts by writing it first and separating it with `ADDCOMM`. This is useful when you want to include elements common to all regions. For example, when generating a picture of three people with different appearances, the instruction "three people" must be included in every region; it is also handy for inserting quality tags. For example, if you write as follows:
+
+```
+best quality, 3persons in garden, ADDCOMM
+a girl white dress BREAK
+a boy blue shirt BREAK
+an old man red suit
+```
+
+With the common prompt applied, this is internally converted to the following:
+
+```
+best quality, 3persons in garden, a girl white dress BREAK
+best quality, 3persons in garden, a boy blue shirt BREAK
+best quality, 3persons in garden, an old man red suit
+```
+
+### Negative prompt
+
+By default, the negative prompt applies equally to all regions, but it is possible to set region-specific negative prompts as well. The number of `BREAK`s must match that of the positive prompt; if it does not, the negative prompt is used without being divided into regions. See the sketch below.
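+
+As a minimal sketch (reusing the pipeline from the sample code above), a region-specific negative prompt simply mirrors the `BREAK` structure of the positive prompt:
+
+```py
+prompt = """
+green hair twintail BREAK
+red blouse BREAK
+blue skirt
+"""
+
+# same number of BREAKs as the positive prompt, so each line applies to its region
+negative_prompt = """
+low quality BREAK
+extra fingers BREAK
+blurry
+"""
+
+images = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    rp_args={"mode": "rows", "div": "1;1;1"},
+).images
+```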
+
+### Parameters
+
+To activate the Regional Prompter, pass the settings through `rp_args` (a dictionary). The items that can be set are as follows.
+
+### Input Parameters
+
+Parameters are specified through `rp_args` (a dictionary).
+
+```py
+rp_args = {
+ "mode":"rows",
+ "div": "1;1;1"
+}
+
+pipe(prompt=prompt, rp_args=rp_args)
+```
+
+### Required Parameters
+
+- `mode`: Specifies the method for defining regions. Choose from `Cols`, `Rows`, `Prompt`, or `Prompt-Ex`. This parameter is case-insensitive.
+- `div`: Used in `Cols` and `Rows` modes. Details on how to specify this are provided under the respective `Cols` and `Rows` sections.
+- `th`: Used in `Prompt` mode. The method of specification is detailed under the `Prompt` section.
+
+### Optional Parameters
+
+- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
+
+The pipeline supports `compel` syntax: prompts written with `compel` weighting are automatically detected and processed. See the sketch below.
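+
+For illustration, a compel-weighted prompt can be dropped straight into the regional syntax. This is a minimal sketch assuming the usual `compel` weighting notation (`word++`, `(phrase)1.3`) and the pipeline from the sample code above:
+
+```py
+prompt = """
+(green hair)1.3 twintail BREAK
+red blouse++ BREAK
+blue skirt
+"""
+
+images = pipe(prompt=prompt, rp_args={"mode": "rows", "div": "1;1;1"}).images
+```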
+
+### Diffusion Posterior Sampling Pipeline
+
+- Reference paper
+
+ ```bibtex
+ @article{chung2022diffusion,
+ title={Diffusion posterior sampling for general noisy inverse problems},
+ author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul},
+ journal={arXiv preprint arXiv:2209.14687},
+ year={2022}
+ }
+ ```
+
+- This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given an observation $y$, an unconditional generative model $p(x)$, and a differentiable operator $y=f(x)$.
+
+- For example, $f(\cdot)$ can be a downsampling operator; then $y$ is a downsampled image, and the pipeline becomes a super-resolution pipeline (see the sketch at the end of this section).
+- To use this pipeline, you need to know your operator $f(\cdot)$ and the corrupted image $y$, and pass them during the call. For example, as in the main function of `dps_pipeline.py`, you first define the Gaussian blurring operator $f(\cdot)$. The operator should be a callable `nn.Module` with all parameter gradients disabled:
+
+ ```python
+  import numpy as np
+  import scipy.ndimage
+  import torch
+  import torch.nn.functional as F
+  from torch import nn
+  # `Kernel` (used only for the "motion" blur type) comes from the external `motionblur` package:
+  # from motionblur.motionblur import Kernel
+
+ # define the Gaussian blurring operator first
+ class GaussialBlurOperator(nn.Module):
+ def __init__(self, kernel_size, intensity):
+ super().__init__()
+
+ class Blurkernel(nn.Module):
+ def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0):
+ super().__init__()
+ self.blur_type = blur_type
+ self.kernel_size = kernel_size
+ self.std = std
+ self.seq = nn.Sequential(
+ nn.ReflectionPad2d(self.kernel_size//2),
+ nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
+ )
+ self.weights_init()
+
+ def forward(self, x):
+ return self.seq(x)
+
+ def weights_init(self):
+ if self.blur_type == "gaussian":
+ n = np.zeros((self.kernel_size, self.kernel_size))
+ n[self.kernel_size // 2, self.kernel_size // 2] = 1
+ k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
+ k = torch.from_numpy(k)
+ self.k = k
+ for name, f in self.named_parameters():
+ f.data.copy_(k)
+ elif self.blur_type == "motion":
+ k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
+ k = torch.from_numpy(k)
+ self.k = k
+ for name, f in self.named_parameters():
+ f.data.copy_(k)
+
+ def update_weights(self, k):
+ if not torch.is_tensor(k):
+ k = torch.from_numpy(k)
+ for name, f in self.named_parameters():
+ f.data.copy_(k)
+
+ def get_kernel(self):
+ return self.k
+
+ self.kernel_size = kernel_size
+ self.conv = Blurkernel(blur_type='gaussian',
+ kernel_size=kernel_size,
+ std=intensity)
+ self.kernel = self.conv.get_kernel()
+ self.conv.update_weights(self.kernel.type(torch.float32))
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, data, **kwargs):
+ return self.conv(data)
+
+ def transpose(self, data, **kwargs):
+ return data
+
+ def get_kernel(self):
+ return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
+ ```
+
+- Next, obtain the corrupted image $y$ by applying the operator. In this example, we generate $y$ from the source image $x$; in practice, having the operator $f(\cdot)$ and the corrupted image $y$ is enough:
+
+ ```python
+  import numpy as np
+  import torch
+  from PIL import Image
+  from torchvision.utils import save_image
+
+  # set up source image
+ src = Image.open('sample.png')
+ # read image into [1,3,H,W]
+ src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2,0,1)[None]
+ # normalize image to [-1,1]
+ src = (src / 127.5) - 1.0
+ src = src.to("cuda")
+
+ # set up operator and measurement
+ operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
+ measurement = operator(src)
+
+ # save the source and corrupted images
+ save_image((src+1.0)/2.0, "dps_src.png")
+ save_image((measurement+1.0)/2.0, "dps_mea.png")
+ ```
+
+- We provide an example pair of saved source and corrupted (Gaussian-blurred) images, generated with the operator above (images omitted here); you can download them to run the example on your own.
+
+- Next, we need to define a loss function for diffusion posterior sampling. For most cases, the RMSE is fine:
+
+ ```python
+ def RMSELoss(yhat, y):
+ return torch.sqrt(torch.sum((yhat-y)**2))
+ ```
+
+- Next, as with any other diffusion model, we need the score estimator and the scheduler. As we are working with $256 \times 256$ face images, we use ddpm-celebahq-256:
+
+ ```python
+  from diffusers import DDPMScheduler, UNet2DModel
+
+  # set up scheduler
+ scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
+ scheduler.set_timesteps(1000)
+
+ # set up model
+ model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
+ ```
+
+- And finally, run the pipeline:
+
+ ```python
+  # finally, the pipeline (DPSPipeline is defined in examples/community/dps_pipeline.py)
+ dpspipe = DPSPipeline(model, scheduler)
+ image = dpspipe(
+ measurement=measurement,
+ operator=operator,
+ loss_fn=RMSELoss,
+ zeta=1.0,
+ ).images[0]
+ image.save("dps_generated_image.png")
+ ```
+
+- `zeta` is a hyperparameter in the range $[0, 1]$ and needs to be tuned for the best effect. Setting `zeta=1` should give you the reconstructed result (image omitted here).
+
+- The reconstruction is perceptually similar to the source image but differs in details.
+- In `dps_pipeline.py`, we also provide a super-resolution example, which should produce a downsampled image and its reconstruction (images omitted here).
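+
+- As a rough sketch of what such a downsampling operator could look like (a hypothetical illustration; the actual operator in `dps_pipeline.py` may differ in detail), it can follow the same callable `nn.Module` pattern as the blur operator above:
+
+  ```python
+  import torch.nn.functional as F
+  from torch import nn
+
+  class DownsampleOperator(nn.Module):
+      def __init__(self, scale_factor=4):
+          super().__init__()
+          # no learnable parameters, so there is nothing to freeze here
+          self.scale_factor = scale_factor
+
+      def forward(self, data, **kwargs):
+          # [1, 3, H, W] -> [1, 3, H // scale_factor, W // scale_factor]
+          return F.interpolate(data, scale_factor=1 / self.scale_factor, mode="bicubic", align_corners=False)
+
+      def transpose(self, data, **kwargs):
+          return data
+
+  # usage sketch: measurement = DownsampleOperator(scale_factor=4).to("cuda")(src)
+  ```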
+
+### AnimateDiff ControlNet Pipeline
+
+This pipeline combines AnimateDiff and ControlNet. Enjoy precise motion control for your videos! Refer to [this](https://github.com/huggingface/diffusers/issues/5866) issue for more details.
+
+```py
+import torch
+from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.utils import export_to_gif
+from PIL import Image
+
+motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
+adapter = MotionAdapter.from_pretrained(motion_id)
+controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = DiffusionPipeline.from_pretrained(
+ model_id,
+ motion_adapter=adapter,
+ controlnet=controlnet,
+ vae=vae,
+ custom_pipeline="pipeline_animatediff_controlnet",
+ torch_dtype=torch.float16,
+).to(device="cuda")
+pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
+ model_id, subfolder="scheduler", beta_schedule="linear", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+)
+pipe.enable_vae_slicing()
+
+conditioning_frames = []
+for i in range(1, 16 + 1):
+ conditioning_frames.append(Image.open(f"frame_{i}.png"))
+
+prompt = "astronaut in space, dancing"
+negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
+result = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=512,
+ height=768,
+ conditioning_frames=conditioning_frames,
+ num_inference_steps=20,
+).frames[0]
+
+export_to_gif(result, "result.gif")
+```
+
+
+Example results (images omitted here): the conditioning frames, and outputs using the AnimateDiff motion adapter with the SG161222/Realistic_Vision_V5.1_noVAE and CardosAnime models.
+
+
+You can also use multiple controlnets at once!
+
+```python
+import torch
+from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.utils import export_to_gif
+from PIL import Image
+
+# the video loading helper below additionally needs these
+import imageio
+import requests
+from io import BytesIO
+
+motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
+adapter = MotionAdapter.from_pretrained(motion_id)
+controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
+controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = DiffusionPipeline.from_pretrained(
+ model_id,
+ motion_adapter=adapter,
+ controlnet=[controlnet1, controlnet2],
+ vae=vae,
+ custom_pipeline="pipeline_animatediff_controlnet",
+ torch_dtype=torch.float16,
+).to(device="cuda")
+pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
+ model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
+)
+pipe.enable_vae_slicing()
+
+def load_video(file_path: str):
+ images = []
+
+ if file_path.startswith(('http://', 'https://')):
+ # If the file_path is a URL
+ response = requests.get(file_path)
+ response.raise_for_status()
+ content = BytesIO(response.content)
+ vid = imageio.get_reader(content)
+ else:
+ # Assuming it's a local file path
+ vid = imageio.get_reader(file_path)
+
+ for frame in vid:
+ pil_image = Image.fromarray(frame)
+ images.append(pil_image)
+
+ return images
+
+video = load_video("dance.gif")
+
+# You need to install it using `pip install controlnet_aux`
+from controlnet_aux.processor import Processor
+
+p1 = Processor("openpose_full")
+cn1 = [p1(frame) for frame in video]
+
+p2 = Processor("canny")
+cn2 = [p2(frame) for frame in video]
+
+prompt = "astronaut in space, dancing"
+negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
+result = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=512,
+ height=768,
+ conditioning_frames=[cn1, cn2],
+ num_inference_steps=20,
+)
+
+export_to_gif(result.frames[0], "result.gif")
+```
+
+### DemoFusion
+
+This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
+The original repo can be found at [PRIS-CV/DemoFusion](https://github.com/PRIS-CV/DemoFusion). The pipeline exposes the following parameters:
+
+- `view_batch_size` (`int`, defaults to 16):
+ The batch size for multiple denoising paths. Typically, a larger batch size can result in higher efficiency but comes with increased GPU memory requirements.
+
+- `stride` (`int`, defaults to 64):
+ The stride of moving local patches. A smaller stride is better for alleviating seam issues, but it also introduces additional computational overhead and inference time.
+
+- `cosine_scale_1` (`float`, defaults to 3):
+ Control the strength of skip-residual. For specific impacts, please refer to Appendix C in the DemoFusion paper.
+
+- `cosine_scale_2` (`float`, defaults to 1):
+ Control the strength of dilated sampling. For specific impacts, please refer to Appendix C in the DemoFusion paper.
+
+- `cosine_scale_3` (`float`, defaults to 1):
+ Control the strength of the Gaussian filter. For specific impacts, please refer to Appendix C in the DemoFusion paper.
+
+- `sigma` (`float`, defaults to 1):
+  The standard deviation of the Gaussian filter. A larger sigma promotes the global guidance of dilated sampling, but may over-smooth the result.
+
+- `multi_decoder` (`bool`, defaults to True):
+ Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072, a tiled decoder becomes necessary.
+
+- `show_image` (`bool`, defaults to False):
+ Determine whether to show intermediate results during generation.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ custom_pipeline="pipeline_demofusion_sdxl",
+ custom_revision="main",
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "Envision a portrait of an elderly woman, her face a canvas of time, framed by a headscarf with muted tones of rust and cream. Her eyes, blue like faded denim. Her attire, simple yet dignified."
+negative_prompt = "blurry, ugly, duplicate, poorly drawn, deformed, mosaic"
+
+images = pipe(
+ prompt,
+ negative_prompt=negative_prompt,
+ height=3072,
+ width=3072,
+ view_batch_size=16,
+ stride=64,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ cosine_scale_1=3,
+ cosine_scale_2=1,
+ cosine_scale_3=1,
+ sigma=0.8,
+ multi_decoder=True,
+ show_image=True
+)
+```
+
+You can display and save the generated images as:
+
+```py
+from PIL import Image
+
+def image_grid(imgs, save_path=None):
+    # stitch the returned images horizontally and optionally save each one
+    w = 0
+    for img in imgs:
+        w_, h_ = img.size
+        w += w_
+        h = h_
+    grid = Image.new('RGB', size=(w, h))
+
+    w = 0
+    for i, img in enumerate(imgs):
+        w_, h_ = img.size
+        grid.paste(img, box=(w, h - h_))
+        if save_path is not None:
+            img.save(save_path + "/img_{}.jpg".format((i + 1) * 1024))
+        w += w_
+
+    return grid
+
+image_grid(images, save_path="./outputs/")
+```
+
+ 
+
+### SDE Drag pipeline
+
+This pipeline provides drag-and-drop image editing using stochastic differential equations. It edits an image given a prompt, the image, a mask image, and lists of source and target points.
+
+
+
+See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information.
+
+```py
+import PIL
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+
+# Load the pipeline
+model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
+pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
+pipe.to('cuda')
+
+# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+# If not training LoRA, please avoid using torch.float16
+# pipe.to(torch.float16)
+
+# Provide prompt, image, mask image, and the starting and target points for drag editing.
+prompt = "prompt of the image"
+image = PIL.Image.open('/path/to/image')
+mask_image = PIL.Image.open('/path/to/mask_image')
+source_points = [[123, 456]]
+target_points = [[234, 567]]
+
+# train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.
+pipe.train_lora(prompt, image)
+
+output = pipe(prompt, image, mask_image, source_points, target_points)
+output_image = PIL.Image.fromarray(output)
+output_image.save("./output.png")
+```
+
+### Instaflow Pipeline
+
+InstaFlow is an ultra-fast, one-step image generator that achieves image quality close to Stable Diffusion while significantly reducing the demand for computational resources. This efficiency is made possible by the recent [Rectified Flow](https://github.com/gnobitab/RectifiedFlow) technique, which trains probability flows with straight trajectories and hence inherently requires only a single step for fast inference.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+
+pipe = DiffusionPipeline.from_pretrained("XCLIU/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16, custom_pipeline="instaflow_one_step")
+pipe.to("cuda") ### if GPU is not available, comment this line
+prompt = "A hyper-realistic photo of a cute cat."
+
+images = pipe(prompt=prompt,
+ num_inference_steps=1,
+ guidance_scale=0.0).images
+images[0].save("./image.png")
+```
+
+
+
+You can also combine it with LoRA out of the box to unlock cool use cases in a single step!
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+
+pipe = DiffusionPipeline.from_pretrained("XCLIU/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16, custom_pipeline="instaflow_one_step")
+pipe.to("cuda") ### if GPU is not available, comment this line
+pipe.load_lora_weights("artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5")
+prompt = "logo, A logo for a fitness app, dynamic running figure, energetic colors (red, orange) ),LogoRedAF ,"
+images = pipe(prompt=prompt,
+ num_inference_steps=1,
+ guidance_scale=0.0).images
+images[0].save("./image.png")
+```
+
+
+
+### Null-Text Inversion pipeline
+
+This pipeline provides null-text inversion for editing real images. It enables null-text optimization and DDIM reconstruction with or without null-text optimization. No prompt-to-prompt code is implemented here, as that is already covered by the Prompt2Prompt pipeline.
+
+- Reference paper
+
+ ```bibtex
+ @article{hertz2022prompt,
+ title={Prompt-to-prompt image editing with cross attention control},
+ author={Hertz, Amir and Mokady, Ron and Tenenbaum, Jay and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel},
+ booktitle={arXiv preprint arXiv:2208.01626},
+  year={2022}
+  }
+  ```
+
+```py
+from diffusers import DDIMScheduler
+from examples.community.pipeline_null_text_inversion import NullTextPipeline
+import torch
+
+device = "cuda"
+# Provide invert_prompt and the image for null-text optimization.
+invert_prompt = "A lying cat"
+input_image = "siamese.jpg"
+steps = 50
+
+# Provide the prompt used for generation: the same as invert_prompt for reconstruction,
+prompt = "A lying cat"
+# or a different one for editing.
+prompt = "A lying dog"
+
+# float32 is essential for good optimization
+model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear")
+pipeline = NullTextPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float32).to(device)
+
+# Saves the inverted_latent to save time
+inverted_latent, uncond = pipeline.invert(input_image, invert_prompt, num_inner_steps=10, early_stop_epsilon=1e-5, num_inference_steps=steps)
+pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_steps=steps).images[0].save(input_image+".output.jpg")
+```
+
+### Rerender A Video
+
+This is the Diffusers implementation of the zero-shot video-to-video translation pipeline [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install `gmflow` and set its path in `gmflow_dir`. After that, you can run the pipeline with:
+
+```py
+import sys
+gmflow_dir = "/path/to/gmflow"
+sys.path.insert(0, gmflow_dir)
+
+from diffusers import ControlNetModel, AutoencoderKL, DDIMScheduler, DiffusionPipeline
+from diffusers.utils import export_to_video
+import numpy as np
+import torch
+
+import cv2
+from PIL import Image
+
+def video_to_frame(video_path: str, interval: int):
+ vidcap = cv2.VideoCapture(video_path)
+ success = True
+
+ count = 0
+ res = []
+ while success:
+ count += 1
+ success, image = vidcap.read()
+ if count % interval != 1:
+ continue
+ if image is not None:
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ res.append(image)
+
+ vidcap.release()
+ return res
+
+input_video_path = 'path/to/video'
+input_interval = 10
+frames = video_to_frame(
+ input_video_path, input_interval)
+
+control_frames = []
+# get canny image
+for frame in frames:
+ np_image = cv2.Canny(frame, 50, 100)
+ np_image = np_image[:, :, None]
+ np_image = np.concatenate([np_image, np_image, np_image], axis=2)
+ canny_image = Image.fromarray(np_image)
+ control_frames.append(canny_image)
+
+# You can use any ControlNet here
+controlnet = ControlNetModel.from_pretrained(
+ "lllyasviel/sd-controlnet-canny").to('cuda')
+
+# You can use any finetuned SD here
+pipe = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, custom_pipeline='rerender_a_video').to('cuda')
+
+# Optional: you can download vae-ft-mse-840000-ema-pruned.ckpt to enhance the results
+# pipe.vae = AutoencoderKL.from_single_file(
+# "path/to/vae-ft-mse-840000-ema-pruned.ckpt").to('cuda')
+
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+generator = torch.manual_seed(0)
+frames = [Image.fromarray(frame) for frame in frames]
+output_frames = pipe(
+ "a beautiful woman in CG style, best quality, extremely detailed",
+ frames,
+ control_frames,
+ num_inference_steps=20,
+ strength=0.75,
+ controlnet_conditioning_scale=0.7,
+ generator=generator,
+ warp_start=0.0,
+ warp_end=0.1,
+ mask_start=0.5,
+ mask_end=0.8,
+ mask_strength=0.5,
+ negative_prompt='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
+).frames[0]
+
+export_to_video(
+ output_frames, "/path/to/video.mp4", 5)
+```
+
+### StyleAligned Pipeline
+
+This pipeline is the implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133). You can find more results [here](https://github.com/huggingface/diffusers/pull/6489#issuecomment-1881209354).
+
+> Large-scale Text-to-Image (T2I) models have rapidly gained prominence across creative fields, generating visually compelling outputs from textual prompts. However, controlling these models to ensure consistent style remains challenging, with existing methods necessitating fine-tuning and manual intervention to disentangle content and style. In this paper, we introduce StyleAligned, a novel technique designed to establish style alignment among a series of generated images. By employing minimal `attention sharing' during the diffusion process, our method maintains style consistency across images within T2I models. This approach allows for the creation of style-consistent images using a reference style through a straightforward inversion operation. Our method's evaluation across diverse styles and text prompts demonstrates high-quality synthesis and fidelity, underscoring its efficacy in achieving consistent style across various inputs.
+
+```python
+from typing import List
+
+import torch
+from diffusers import DiffusionPipeline
+from PIL import Image
+
+model_id = "a-r-r-o-w/dreamshaper-xl-turbo"
+pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", custom_pipeline="pipeline_sdxl_style_aligned")
+pipe = pipe.to("cuda")
+
+# Enable memory saving techniques
+pipe.enable_vae_slicing()
+pipe.enable_vae_tiling()
+
+prompt = [
+ "a toy train. macro photo. 3d game asset",
+ "a toy airplane. macro photo. 3d game asset",
+ "a toy bicycle. macro photo. 3d game asset",
+ "a toy car. macro photo. 3d game asset",
+]
+negative_prompt = "low quality, worst quality, "
+
+# Enable StyleAligned
+pipe.enable_style_aligned(
+ share_group_norm=False,
+ share_layer_norm=False,
+ share_attention=True,
+ adain_queries=True,
+ adain_keys=True,
+ adain_values=False,
+ full_attention_share=False,
+ shared_score_scale=1.0,
+ shared_score_shift=0.0,
+ only_self_level=0.0,
+)
+
+# Run inference
+images = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ guidance_scale=2,
+ height=1024,
+ width=1024,
+ num_inference_steps=10,
+ generator=torch.Generator().manual_seed(42),
+).images
+
+# Disable StyleAligned if you do not wish to use it anymore
+pipe.disable_style_aligned()
+```
+
+### AnimateDiff Image-To-Video Pipeline
+
+This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.
+
+This pipeline relies on a "hack" discovered by the community that allows the generation of videos from an input image with AnimateDiff. It works by creating `num_frames` copies of the image and progressively adding more noise to them based on the strength and latent interpolation method; a rough sketch of this idea follows the usage example below.
+
+```py
+import torch
+from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif, load_image
+
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+pipe = DiffusionPipeline.from_pretrained(model_id, motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
+pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)
+
+image = load_image("snail.png")
+output = pipe(
+ image=image,
+ prompt="A snail moving on the ground",
+ strength=0.8,
+ latent_interpolation_method="slerp", # can be lerp, slerp, or your own callback
+)
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
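+
+For intuition, here is a rough sketch of the idea behind this hack in plain PyTorch. It is only an illustration under assumptions (the helper name `make_video_latents` and the linear timestep ramp are made up here), not the pipeline's actual implementation, which additionally blends neighboring latents with the chosen `latent_interpolation_method` (lerp/slerp).
+
+```py
+import torch
+from diffusers import DDIMScheduler
+
+def make_video_latents(image_latent: torch.Tensor, num_frames: int, strength: float, scheduler) -> torch.Tensor:
+    # image_latent: (1, C, H, W) VAE latent of the conditioning image
+    latents = image_latent.repeat(num_frames, 1, 1, 1)
+    noise = torch.randn_like(latents)
+    # Later frames are noised to progressively larger timesteps, scaled by `strength`
+    max_t = int(strength * (scheduler.config.num_train_timesteps - 1))
+    timesteps = torch.linspace(0, max_t, num_frames).long()
+    return scheduler.add_noise(latents, noise, timesteps)
+
+# Hypothetical shapes: a single 64x64 SD latent expanded to 16 video frames
+scheduler = DDIMScheduler(num_train_timesteps=1000)
+video_latents = make_video_latents(torch.randn(1, 4, 64, 64), num_frames=16, strength=0.8, scheduler=scheduler)
+```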
+
+### IP Adapter Face ID
+
+IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
+You need to install `insightface` and all its requirements to use this model.
+You must pass the image embedding tensor as `image_embeds` to the `DiffusionPipeline` instead of `ip_adapter_image`.
+You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).
+
+```py
+import torch
+from diffusers.utils import load_image
+import cv2
+import numpy as np
+from diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler
+from insightface.app import FaceAnalysis
+
+
+noise_scheduler = DDIMScheduler(
+ num_train_timesteps=1000,
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ steps_offset=1,
+)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float16)
+pipeline = DiffusionPipeline.from_pretrained(
+ "SG161222/Realistic_Vision_V4.0_noVAE",
+ torch_dtype=torch.float16,
+ scheduler=noise_scheduler,
+ vae=vae,
+ custom_pipeline="ip_adapter_face_id"
+)
+pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin")
+pipeline.to("cuda")
+
+generator = torch.Generator(device="cpu").manual_seed(42)
+num_images = 2
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
+
+app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+app.prepare(ctx_id=0, det_size=(640, 640))
+image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
+faces = app.get(image)
+image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
+images = pipeline(
+ prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
+ image_embeds=image,
+ negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+ num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704,
+ generator=generator
+).images
+
+for i in range(num_images):
+ images[i].save(f"c{i}.png")
+```
+
+### InstantID Pipeline
+
+InstantID is a new state-of-the-art tuning-free method to achieve ID-preserving generation with only a single image, supporting various downstream tasks. For any usage questions, please refer to the [official implementation](https://github.com/InstantID/InstantID).
+
+```py
+# !pip install diffusers opencv-python transformers accelerate insightface
+import diffusers
+from diffusers.utils import load_image
+from diffusers import ControlNetModel
+
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+
+from insightface.app import FaceAnalysis
+from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
+
+# prepare 'antelopev2' under ./models
+# https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304
+app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+app.prepare(ctx_id=0, det_size=(640, 640))
+
+# prepare models under ./checkpoints
+# https://huggingface.co/InstantX/InstantID
+from huggingface_hub import hf_hub_download
+
+hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
+hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
+hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
+
+face_adapter = './checkpoints/ip-adapter.bin'
+controlnet_path = './checkpoints/ControlNetModel'
+
+# load IdentityNet
+controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+
+base_model = 'wangqixun/YamerMIX_v8'
+pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
+ base_model,
+ controlnet=controlnet,
+ torch_dtype=torch.float16
+)
+pipe.to("cuda")
+
+# load adapter
+pipe.load_ip_adapter_instantid(face_adapter)
+
+# load an image
+face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
+
+# prepare face emb
+face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
+face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]  # only use the largest face
+face_emb = face_info['embedding']
+face_kps = draw_kps(face_image, face_info['kps'])
+
+# prompt
+prompt = "film noir style, ink sketch|vector, male man, highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic"
+negative_prompt = "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, vibrant, colorful"
+
+# generate image
+pipe.set_ip_adapter_scale(0.8)
+image = pipe(
+ prompt,
+ image_embeds=face_emb,
+ image=face_kps,
+ controlnet_conditioning_scale=0.8,
+).images[0]
+```
+
+### UFOGen Scheduler
+
+[UFOGen](https://arxiv.org/abs/2311.09257) is a generative model designed for fast one-step text-to-image generation, trained via adversarial training starting from an initial pretrained diffusion model such as Stable Diffusion. `scheduling_ufogen.py` implements one-step and multi-step sampling algorithms for UFOGen models that are compatible with pipelines like `StableDiffusionPipeline`. A usage example is as follows:
+
+```py
+import torch
+from diffusers import StableDiffusionPipeline
+
+from scheduling_ufogen import UFOGenScheduler
+
+# NOTE: currently, I am not aware of any publicly available UFOGen model checkpoints trained from SD v1.5.
+ufogen_model_id_or_path = "/path/to/ufogen/model"
+pipe = StableDiffusionPipeline.from_pretrained(
+    ufogen_model_id_or_path,
+    torch_dtype=torch.float16,
+)
+
+# You can initialize a UFOGenScheduler as follows:
+pipe.scheduler = UFOGenScheduler.from_config(pipe.scheduler.config)
+
+prompt = "Three cats having dinner at a table at new years eve, cinematic shot, 8k."
+
+# Onestep sampling
+onestep_image = pipe(prompt, num_inference_steps=1).images[0]
+
+# Multistep sampling
+multistep_image = pipe(prompt, num_inference_steps=4).images[0]
+```
+
+### FRESCO
+
+This is the Diffusers implementation of the zero-shot video-to-video translation pipeline [FRESCO](https://github.com/williamyang1991/FRESCO) (without Ebsynth postprocessing and background smoothing). To run the code, please install gmflow, then modify the path in `gmflow_dir`. After that, you can run the pipeline with:
+
+```py
+from PIL import Image
+import cv2
+import torch
+import numpy as np
+
+from diffusers import ControlNetModel, DDIMScheduler, DiffusionPipeline
+import sys
+
+gmflow_dir = "/path/to/gmflow"
+sys.path.insert(0, gmflow_dir)
+
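+# Read the video and keep every interval-th frame converted to RGB, stopping after 8 frames.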
+def video_to_frame(video_path: str, interval: int):
+ vidcap = cv2.VideoCapture(video_path)
+ success = True
+
+ count = 0
+ res = []
+ while success:
+ count += 1
+ success, image = vidcap.read()
+ if count % interval != 1:
+ continue
+ if image is not None:
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ res.append(image)
+ if len(res) >= 8:
+ break
+
+ vidcap.release()
+ return res
+
+
+input_video_path = 'https://github.com/williamyang1991/FRESCO/raw/main/data/car-turn.mp4'
+output_video_path = 'car.gif'
+
+# You can use any finetuned SD here
+model_path = 'SG161222/Realistic_Vision_V2.0'
+
+prompt = 'a red car turns in the winter'
+a_prompt = ', RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, '
+n_prompt = '(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation'
+
+input_interval = 5
+frames = video_to_frame(
+ input_video_path, input_interval)
+
+control_frames = []
+# get canny image
+for frame in frames:
+ image = cv2.Canny(frame, 50, 100)
+ np_image = np.array(image)
+ np_image = np_image[:, :, None]
+ np_image = np.concatenate([np_image, np_image, np_image], axis=2)
+ canny_image = Image.fromarray(np_image)
+ control_frames.append(canny_image)
+
+# You can use any ControlNet here
+controlnet = ControlNetModel.from_pretrained(
+ "lllyasviel/sd-controlnet-canny").to('cuda')
+
+pipe = DiffusionPipeline.from_pretrained(
+ model_path, controlnet=controlnet, custom_pipeline='fresco_v2v').to('cuda')
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+generator = torch.manual_seed(0)
+frames = [Image.fromarray(frame) for frame in frames]
+
+output_frames = pipe(
+ prompt + a_prompt,
+ frames,
+ control_frames,
+ num_inference_steps=20,
+ strength=0.75,
+ controlnet_conditioning_scale=0.7,
+ generator=generator,
+ negative_prompt=n_prompt
+).images
+
+output_frames[0].save(output_video_path, save_all=True,
+ append_images=output_frames[1:], duration=100, loop=0)
+```
+
+### AnimateDiff on IPEX
+
+This diffusion pipeline aims to accelerate the inference of AnimateDiff on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).
+
+To use this pipeline, you need to:
+1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)
+
+**Note:** For each PyTorch release, there is a corresponding release of IPEX; the mapping is shown below. It is recommended to install PyTorch/IPEX 2.3 for the best performance.
+
+|PyTorch Version|IPEX Version|
+|--|--|
+|[v2.3.\*](https://github.com/pytorch/pytorch/tree/v2.3.0 "v2.3.0")|[v2.3.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0+cpu)|
+|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|
+
+You can install the latest version of IPEX with pip:
+```sh
+python -m pip install intel_extension_for_pytorch
+```
+**Note:** To install a specific version, run the following command:
+```sh
+python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
+```
+2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX acceleration. Supported inference datatypes are Float32 and BFloat16.
+
+```python
+pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+# For Float32
+pipe.prepare_for_ipex(torch.float32, prompt="A girl smiling")
+# For BFloat16
+pipe.prepare_for_ipex(torch.bfloat16, prompt="A girl smiling")
+```
+
+You can then use the IPEX pipeline in the same way as the default AnimateDiff pipeline.
+```python
+# For Float32
+output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
+# For BFloat16
+with torch.cpu.amp.autocast(enabled = True, dtype = torch.bfloat16):
+ output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
+```
+
+The following code compares the performance of the original AnimateDiff pipeline with the IPEX-optimized pipeline.
+With the optimized pipeline, we can get about a 1.5-2.2x performance boost with BFloat16 on 5th-generation Intel Xeon CPUs (code-named Emerald Rapids).
+
+```python
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, EulerDiscreteScheduler
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+from pipeline_animatediff_ipex import AnimateDiffPipelineIpex
+import time
+
+device = "cpu"
+dtype = torch.float32
+
+prompt = "A girl smiling"
+step = 8 # Options: [1,2,4,8]
+repo = "ByteDance/AnimateDiff-Lightning"
+ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
+base = "emilianJR/epiCRealism"  # Choose your favorite base model.
+
+adapter = MotionAdapter().to(device, dtype)
+adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+
+# Helper function for time evaluation
+def elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):
+ # warmup
+ for _ in range(2):
+ output = pipeline(prompt = prompt, guidance_scale=1.0, num_inference_steps = num_inference_steps)
+ #time evaluation
+ start = time.time()
+ for _ in range(nb_pass):
+ pipeline(prompt = prompt, guidance_scale=1.0, num_inference_steps = num_inference_steps)
+ end = time.time()
+ return (end - start) / nb_pass
+
+############## bf16 inference performance ###############
+
+# 1. IPEX Pipeline initialization
+pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+pipe.prepare_for_ipex(torch.bfloat16, prompt = prompt)
+
+# 2. Original Pipeline initialization
+pipe2 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+pipe2.scheduler = EulerDiscreteScheduler.from_config(pipe2.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ latency = elapsed_time(pipe, num_inference_steps=step)
+ print("Latency of AnimateDiffPipelineIpex--bf16", latency, "s for total", step, "steps")
+ latency = elapsed_time(pipe2, num_inference_steps=step)
+ print("Latency of AnimateDiffPipeline--bf16", latency, "s for total", step, "steps")
+
+############## fp32 inference performance ###############
+
+# 1. IPEX Pipeline initialization
+pipe3 = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+pipe3.scheduler = EulerDiscreteScheduler.from_config(pipe3.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+pipe3.prepare_for_ipex(torch.float32, prompt = prompt)
+
+# 2. Original Pipeline initialization
+pipe4 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+pipe4.scheduler = EulerDiscreteScheduler.from_config(pipe4.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+latency = elapsed_time(pipe3, num_inference_steps=step)
+print("Latency of AnimateDiffPipelineIpex--fp32", latency, "s for total", step, "steps")
+latency = elapsed_time(pipe4, num_inference_steps=step)
+print("Latency of AnimateDiffPipeline--fp32",latency, "s for total", step, "steps")
+```
+
+### HunyuanDiT with Differential Diffusion
+
+#### Usage
+
+```python
+import torch
+from diffusers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import load_image
+from PIL import Image
+from torchvision import transforms
+
+from pipeline_hunyuandit_differential_img2img import (
+ HunyuanDiTDifferentialImg2ImgPipeline,
+)
+
+
+pipe = HunyuanDiTDifferentialImg2ImgPipeline.from_pretrained(
+ "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
+).to("cuda")
+
+
+source_image = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png"
+)
+map = load_image(
+ "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask_2.png"
+)
+prompt = "a green pear"
+negative_prompt = "blurry"
+
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=source_image,
+ num_inference_steps=28,
+ guidance_scale=4.5,
+ strength=1.0,
+ map=map,
+).images[0]
+```
+
+| Gradient | Input | Output |
+| --- | --- | --- |
+|  |  |  |
+
+A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
+
+# Perturbed-Attention Guidance
+
+[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
+
+This implementation is based on [Diffusers](https://huggingface.co/docs/diffusers/index). `StableDiffusionPAGPipeline` is a modification of `StableDiffusionPipeline` to support Perturbed-Attention Guidance (PAG).
+
+## Example Usage
+
+```py
+import os
+import torch
+
+from accelerate.utils import set_seed
+
+from diffusers import StableDiffusionPipeline
+from diffusers.utils import load_image, make_image_grid
+from diffusers.utils.torch_utils import randn_tensor
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ custom_pipeline="hyoungwoncho/sd_perturbed_attention_guidance",
+ torch_dtype=torch.float16
+)
+
+device = "cuda"
+pipe = pipe.to(device)
+
+pag_scale = 5.0
+pag_applied_layers_index = ['m0']
+
+batch_size = 4
+seed = 10
+
+base_dir = "./results/"
+grid_dir = base_dir + "/pag" + str(pag_scale) + "/"
+
+if not os.path.exists(grid_dir):
+ os.makedirs(grid_dir)
+
+set_seed(seed)
+
+latent_input = randn_tensor(shape=(batch_size,4,64,64), generator=None, device=device, dtype=torch.float16)
+
+output_baseline = pipe(
+ "",
+ width=512,
+ height=512,
+ num_inference_steps=50,
+ guidance_scale=0.0,
+ pag_scale=0.0,
+ pag_applied_layers_index=pag_applied_layers_index,
+ num_images_per_prompt=batch_size,
+ latents=latent_input
+).images
+
+output_pag = pipe(
+ "",
+ width=512,
+ height=512,
+ num_inference_steps=50,
+ guidance_scale=0.0,
+ pag_scale=5.0,
+ pag_applied_layers_index=pag_applied_layers_index,
+ num_images_per_prompt=batch_size,
+ latents=latent_input
+).images
+
+grid_image = make_image_grid(output_baseline + output_pag, rows=2, cols=batch_size)
+grid_image.save(grid_dir + "sample.png")
+```
+
+## PAG Parameters
+
+`pag_scale`: the guidance scale of PAG (e.g., 5.0)
+
+`pag_applied_layers_index`: the index of the layers to which the perturbation is applied (e.g., ['m0'])
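+
+For reference, here is a minimal text-conditioned variant, continuing from the example above (it reuses `pipe`, `pag_scale`, `pag_applied_layers_index`, and `grid_dir`). The prompt and guidance values below are illustrative choices, not recommended settings; the call signature is the same one used in the unconditional example, which already passes both `guidance_scale` and `pag_scale`.
+
+```py
+prompt = "a corgi wearing a space suit, studio lighting"
+
+image = pipe(
+    prompt,
+    width=512,
+    height=512,
+    num_inference_steps=50,
+    guidance_scale=7.5,  # classifier-free guidance combined with PAG
+    pag_scale=pag_scale,
+    pag_applied_layers_index=pag_applied_layers_index,
+    num_images_per_prompt=1,
+).images[0]
+
+image.save(grid_dir + "sample_text_prompt.png")
+```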
diff --git a/diffusers/examples/community/README_community_scripts.md b/diffusers/examples/community/README_community_scripts.md
new file mode 100644
index 0000000000000000000000000000000000000000..8432b4e82c9f594449c3ce7b1d6f18711a5ce795
--- /dev/null
+++ b/diffusers/examples/community/README_community_scripts.md
@@ -0,0 +1,232 @@
+# Community Scripts
+
+**Community scripts** consist of inference examples using Diffusers pipelines that have been added by the community.
+Please have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste code example that you can try out.
+If a community script doesn't work as expected, please open an issue and ping the author on it.
+
+| Example | Description | Code Example | Colab | Author |
+|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
+| Using IP-Adapter with negative noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) | | [Álvaro Somoza](https://github.com/asomoza)|
+| Asymmetric Tiling | Configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#asymmetric-tiling) | | [alexisrolland](https://github.com/alexisrolland)|
+
+
+## Example usages
+
+### IP Adapter Negative Noise
+
+Diffusers pipelines are fully integrated with IP-Adapter, which allows you to prompt the diffusion model with an image. However, it does not support negative image prompts (there is no `negative_ip_adapter_image` argument) the same way it supports negative text prompts. When you pass an `ip_adapter_image`, it will create a zero-filled tensor as the negative image. This script shows you how to create negative noise from `ip_adapter_image` and use it to significantly improve generation quality while preserving the composition of the images.
+
+[cubiq](https://github.com/cubiq) initially developed this feature in his [repository](https://github.com/cubiq/ComfyUI_IPAdapter_plus). The community script was contributed by [asomoza](https://github.com/asomoza). You can find more details about this experimentation in [this discussion](https://github.com/huggingface/diffusers/discussions/7167).
+
+IP-Adapter without negative noise
+|source|result|
+|---|---|
+|||
+
+IP-Adapter with negative noise
+|source|result|
+|---|---|
+|||
+
+```python
+import torch
+
+from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, StableDiffusionXLPipeline
+from diffusers.models import ImageProjection
+from diffusers.utils import load_image
+
+
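+# Encode the IP-Adapter image with the CLIP image encoder. When hidden states are requested and a
+# negative_image is provided, its embedding replaces the default zero-image unconditional embedding.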
+def encode_image(
+ image_encoder,
+ feature_extractor,
+ image,
+ device,
+ num_images_per_prompt,
+ output_hidden_states=None,
+ negative_image=None,
+):
+ dtype = next(image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if negative_image is None:
+ uncond_image_enc_hidden_states = image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ else:
+ if not isinstance(negative_image, torch.Tensor):
+ negative_image = feature_extractor(negative_image, return_tensors="pt").pixel_values
+ negative_image = negative_image.to(device=device, dtype=dtype)
+ uncond_image_enc_hidden_states = image_encoder(negative_image, output_hidden_states=True).hidden_states[-2]
+
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+
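+# Build one image-embedding tensor per IP-Adapter, forwarding the optional negative image to
+# encode_image and concatenating negative/positive embeddings when classifier-free guidance is used.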
+@torch.no_grad()
+def prepare_ip_adapter_image_embeds(
+ unet,
+ image_encoder,
+ feature_extractor,
+ ip_adapter_image,
+ do_classifier_free_guidance,
+ device,
+ num_images_per_prompt,
+ ip_adapter_negative_image=None,
+):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = encode_image(
+ image_encoder,
+ feature_extractor,
+ single_ip_adapter_image,
+ device,
+ 1,
+ output_hidden_state,
+ negative_image=ip_adapter_negative_image,
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+
+vae = AutoencoderKL.from_pretrained(
+ "madebyollin/sdxl-vae-fp16-fix",
+ torch_dtype=torch.float16,
+).to("cuda")
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "RunDiffusion/Juggernaut-XL-v9",
+ torch_dtype=torch.float16,
+ vae=vae,
+ variant="fp16",
+).to("cuda")
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+pipeline.scheduler.config.use_karras_sigmas = True
+
+pipeline.load_ip_adapter(
+ "h94/IP-Adapter",
+ subfolder="sdxl_models",
+ weight_name="ip-adapter-plus_sdxl_vit-h.safetensors",
+ image_encoder_folder="models/image_encoder",
+)
+pipeline.set_ip_adapter_scale(0.7)
+
+ip_image = load_image("source.png")
+negative_ip_image = load_image("noise.png")
+
+image_embeds = prepare_ip_adapter_image_embeds(
+ unet=pipeline.unet,
+ image_encoder=pipeline.image_encoder,
+ feature_extractor=pipeline.feature_extractor,
+ ip_adapter_image=[[ip_image]],
+ do_classifier_free_guidance=True,
+ device="cuda",
+ num_images_per_prompt=1,
+ ip_adapter_negative_image=negative_ip_image,
+)
+
+
+prompt = "cinematic photo of a cyborg in the city, 4k, high quality, intricate, highly detailed"
+negative_prompt = "blurry, smooth, plastic"
+
+image = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ ip_adapter_image_embeds=image_embeds,
+ guidance_scale=6.0,
+ num_inference_steps=25,
+ generator=torch.Generator(device="cpu").manual_seed(1556265306),
+).images[0]
+
+image.save("result.png")
+```
+
+### Asymmetric Tiling
+Stable Diffusion is not trained to generate seamless textures. However, you can use this simple script to add tiling to your generations. This script was contributed by [alexisrolland](https://github.com/alexisrolland). See more details in [this issue](https://github.com/huggingface/diffusers/issues/556).
+
+
+|Generated|Tiled|
+|---|---|
+|||
+
+
+```py
+import torch
+from typing import Optional
+from diffusers import StableDiffusionPipeline
+from diffusers.models.lora import LoRACompatibleConv
+
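+# Patch every Conv2d in the VAE, text encoder and UNet so that padding wraps around (circular mode)
+# on the selected axes, which makes the generated image tile seamlessly along those axes.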
+def seamless_tiling(pipeline, x_axis, y_axis):
+ def asymmetric_conv2d_convforward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+ self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
+ self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
+ working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
+ working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
+ return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)
+ x_mode = 'circular' if x_axis else 'constant'
+ y_mode = 'circular' if y_axis else 'constant'
+ targets = [pipeline.vae, pipeline.text_encoder, pipeline.unet]
+ convolution_layers = []
+ for target in targets:
+ for module in target.modules():
+ if isinstance(module, torch.nn.Conv2d):
+ convolution_layers.append(module)
+ for layer in convolution_layers:
+ if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
+ layer.lora_layer = lambda * x: 0
+ layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)
+ return pipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True)
+pipeline.enable_model_cpu_offload()
+prompt = ["texture of a red brick wall"]
+seed = 123456
+generator = torch.Generator(device='cuda').manual_seed(seed)
+
+pipeline = seamless_tiling(pipeline=pipeline, x_axis=True, y_axis=True)
+image = pipeline(
+ prompt=prompt,
+ width=512,
+ height=512,
+ num_inference_steps=20,
+ guidance_scale=7,
+ num_images_per_prompt=1,
+ generator=generator
+).images[0]
+seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False)
+
+torch.cuda.empty_cache()
+image.save('image.png')
+```
\ No newline at end of file
diff --git a/diffusers/examples/community/bit_diffusion.py b/diffusers/examples/community/bit_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d8f31163f2db719c418b5cbb5f54479c4b9033
--- /dev/null
+++ b/diffusers/examples/community/bit_diffusion.py
@@ -0,0 +1,264 @@
+from typing import Optional, Tuple, Union
+
+import torch
+from einops import rearrange, reduce
+
+from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
+from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
+from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
+
+
+BITS = 8
+
+
+# convert to bit representations and back taken from https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py
+def decimal_to_bits(x, bits=BITS):
+ """expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1"""
+ device = x.device
+
+ x = (x * 255).int().clamp(0, 255)
+
+ mask = 2 ** torch.arange(bits - 1, -1, -1, device=device)
+ mask = rearrange(mask, "d -> d 1 1")
+ x = rearrange(x, "b c h w -> b c 1 h w")
+
+ bits = ((x & mask) != 0).float()
+ bits = rearrange(bits, "b c d h w -> b (c d) h w")
+ bits = bits * 2 - 1
+ return bits
+
+
+def bits_to_decimal(x, bits=BITS):
+ """expects bits from -1 to 1, outputs image tensor from 0 to 1"""
+ device = x.device
+
+ x = (x > 0).int()
+ mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32)
+
+ mask = rearrange(mask, "d -> d 1 1")
+ x = rearrange(x, "b (c d) h w -> b c d h w", d=8)
+ dec = reduce(x * mask, "b c d h w -> b c h w", "sum")
+ return (dec / 255).clamp(0.0, 1.0)
+
+
+# modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale
+def ddim_bit_scheduler_step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ sample: torch.Tensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = True,
+ generator=None,
+ return_dict: bool = True,
+) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ Args:
+ model_output (`torch.Tensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): TODO
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
+ Returns:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+ # 4. Clip "predicted x_0"
+ scale = self.bit_scale
+ if self.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the model_output is always re-derived from the clipped x_0 in Glide
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
+ device = model_output.device if torch.is_tensor(model_output) else "cpu"
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
+ variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+
+def ddpm_bit_scheduler_step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ sample: torch.Tensor,
+ prediction_type="epsilon",
+ generator=None,
+ return_dict: bool = True,
+) -> Union[DDPMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ Args:
+ model_output (`torch.Tensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ current instance of sample being created by diffusion process.
+ prediction_type (`str`, default `epsilon`):
+ indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
+ Returns:
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ t = timestep
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif prediction_type == "sample":
+ pred_original_sample = model_output
+ else:
+ raise ValueError(f"Unsupported prediction_type {prediction_type}.")
+
+ # 3. Clip "predicted x_0"
+ scale = self.bit_scale
+ if self.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
+ current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ noise = torch.randn(
+ model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
+ ).to(model_output.device)
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+
+
+class BitDiffusion(DiffusionPipeline):
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, DDPMScheduler],
+ bit_scale: Optional[float] = 1.0,
+ ):
+ super().__init__()
+ self.bit_scale = bit_scale
+ self.scheduler.step = (
+ ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step
+ )
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ height: Optional[int] = 256,
+ width: Optional[int] = 256,
+ num_inference_steps: Optional[int] = 50,
+ generator: Optional[torch.Generator] = None,
+ batch_size: Optional[int] = 1,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ latents = torch.randn(
+ (batch_size, self.unet.config.in_channels, height, width),
+ generator=generator,
+ )
+ latents = decimal_to_bits(latents) * self.bit_scale
+ latents = latents.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # predict the noise residual
+ noise_pred = self.unet(latents, t).sample
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ image = bits_to_decimal(latents)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/examples/community/checkpoint_merger.py b/diffusers/examples/community/checkpoint_merger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ba4b8c6e837964723aee3e7f573b84f2470493f
--- /dev/null
+++ b/diffusers/examples/community/checkpoint_merger.py
@@ -0,0 +1,284 @@
+import glob
+import os
+from typing import Dict, List, Union
+
+import safetensors.torch
+import torch
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import validate_hf_hub_args
+
+from diffusers import DiffusionPipeline, __version__
+from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from diffusers.utils import CONFIG_NAME, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
+
+
+class CheckpointMergerPipeline(DiffusionPipeline):
+ """
+ A class that supports merging diffusion models based on the discussion here:
+ https://github.com/huggingface/diffusers/issues/877
+
+ Example usage:-
+
+ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger.py")
+
+ merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","prompthero/openjourney"], interp = 'inv_sigmoid', alpha = 0.8, force = True)
+
+ merged_pipe.to('cuda')
+
+ prompt = "An astronaut riding a unicycle on Mars"
+
+ results = merged_pipe(prompt)
+
+ ## For more details, see the docstring for the merge method.
+
+ """
+
+ def __init__(self):
+ self.register_to_config()
+ super().__init__()
+
+ def _compare_model_configs(self, dict0, dict1):
+ if dict0 == dict1:
+ return True
+ else:
+ config0, meta_keys0 = self._remove_meta_keys(dict0)
+ config1, meta_keys1 = self._remove_meta_keys(dict1)
+ if config0 == config1:
+ print(f"Warning !: Mismatch in keys {meta_keys0} and {meta_keys1}.")
+ return True
+ return False
+
+ def _remove_meta_keys(self, config_dict: Dict):
+ meta_keys = []
+ temp_dict = config_dict.copy()
+ for key in config_dict.keys():
+ if key.startswith("_"):
+ temp_dict.pop(key)
+ meta_keys.append(key)
+ return (temp_dict, meta_keys)
+
+ @torch.no_grad()
+ @validate_hf_hub_args
+ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
+ """
+ Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints(weights) of the models passed
+ in the argument 'pretrained_model_name_or_path_list' as a list.
+
+ Parameters:
+ -----------
+ pretrained_model_name_or_path_list : A list of valid pretrained model names in the HuggingFace hub or paths to locally stored models in the HuggingFace format.
+
+ **kwargs:
+ Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
+
+ cache_dir, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.
+
+            alpha - The interpolation parameter. Ranges from 0 to 1. It controls the ratio in which the checkpoints are merged: with an alpha
+                of 0.8, the first model's checkpoint affects the final result far less than it would with an alpha of 0.2.
+
+ interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None.
+ Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_diff" is supported.
+
+ force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
+
+ variant - which variant of a pretrained model to load, e.g. "fp16" (None)
+
+ """
+ # Default kwargs from DiffusionPipeline
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ token = kwargs.pop("token", None)
+ variant = kwargs.pop("variant", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ device_map = kwargs.pop("device_map", None)
+
+ alpha = kwargs.pop("alpha", 0.5)
+ interp = kwargs.pop("interp", None)
+
+ print("Received list", pretrained_model_name_or_path_list)
+ print(f"Combining with alpha={alpha}, interpolation mode={interp}")
+
+ checkpoint_count = len(pretrained_model_name_or_path_list)
+ # Ignore result from model_index_json comparison of the two checkpoints
+ force = kwargs.pop("force", False)
+
+ # If less than 2 checkpoints, nothing to merge. If more than 3, not supported for now.
+ if checkpoint_count > 3 or checkpoint_count < 2:
+ raise ValueError(
+ "Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being"
+ " passed."
+ )
+
+ print("Received the right number of checkpoints")
+ # chkpt0, chkpt1 = pretrained_model_name_or_path_list[0:2]
+ # chkpt2 = pretrained_model_name_or_path_list[2] if checkpoint_count == 3 else None
+
+ # Validate that the checkpoints can be merged
+ # Step 1: Load the model config and compare the checkpoints. We'll compare the model_index.json first while ignoring the keys starting with '_'
+ config_dicts = []
+ for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
+ config_dict = DiffusionPipeline.load_config(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ )
+ config_dicts.append(config_dict)
+
+ comparison_result = True
+ for idx in range(1, len(config_dicts)):
+ comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
+ if not force and comparison_result is False:
+ raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
+ print("Compatible model_index.json files found")
+ # Step 2: Basic Validation has succeeded. Let's download the models and save them into our local files.
+ cached_folders = []
+ for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
+ folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+ allow_patterns = [os.path.join(k, "*") for k in folder_names]
+ allow_patterns += [
+ WEIGHTS_NAME,
+ SCHEDULER_CONFIG_NAME,
+ CONFIG_NAME,
+ ONNX_WEIGHTS_NAME,
+ DiffusionPipeline.config_name,
+ ]
+ requested_pipeline_class = config_dict.get("_class_name")
+ user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
+
+ cached_folder = (
+ pretrained_model_name_or_path
+ if os.path.isdir(pretrained_model_name_or_path)
+ else snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ user_agent=user_agent,
+ )
+ )
+ print("Cached Folder", cached_folder)
+ cached_folders.append(cached_folder)
+
+ # Step 3:-
+ # Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
+ final_pipe = DiffusionPipeline.from_pretrained(
+ cached_folders[0],
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ variant=variant,
+ )
+ final_pipe.to(self.device)
+
+ checkpoint_path_2 = None
+ if len(cached_folders) > 2:
+ checkpoint_path_2 = os.path.join(cached_folders[2])
+
+ if interp == "sigmoid":
+ theta_func = CheckpointMergerPipeline.sigmoid
+ elif interp == "inv_sigmoid":
+ theta_func = CheckpointMergerPipeline.inv_sigmoid
+ elif interp == "add_diff":
+ theta_func = CheckpointMergerPipeline.add_difference
+ else:
+ theta_func = CheckpointMergerPipeline.weighted_sum
+
+ # Find each module's state dict.
+ for attr in final_pipe.config.keys():
+ if not attr.startswith("_"):
+ checkpoint_path_1 = os.path.join(cached_folders[1], attr)
+ if os.path.exists(checkpoint_path_1):
+ files = [
+ *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
+ *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
+ ]
+ checkpoint_path_1 = files[0] if len(files) > 0 else None
+ if len(cached_folders) < 3:
+ checkpoint_path_2 = None
+ else:
+ checkpoint_path_2 = os.path.join(cached_folders[2], attr)
+ if os.path.exists(checkpoint_path_2):
+ files = [
+ *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
+ *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
+ ]
+ checkpoint_path_2 = files[0] if len(files) > 0 else None
+ # For an attr if both checkpoint_path_1 and 2 are None, ignore.
+ # If at least one is present, deal with it according to interp method, of course only if the state_dict keys match.
+ if checkpoint_path_1 is None and checkpoint_path_2 is None:
+                    print(f"Skipping {attr}: not present in the 2nd or 3rd model")
+ continue
+ try:
+ module = getattr(final_pipe, attr)
+ if isinstance(module, bool): # ignore requires_safety_checker boolean
+ continue
+ theta_0 = getattr(module, "state_dict")
+ theta_0 = theta_0()
+
+ update_theta_0 = getattr(module, "load_state_dict")
+ theta_1 = (
+ safetensors.torch.load_file(checkpoint_path_1)
+ if (checkpoint_path_1.endswith(".safetensors"))
+ else torch.load(checkpoint_path_1, map_location="cpu")
+ )
+ theta_2 = None
+ if checkpoint_path_2:
+ theta_2 = (
+ safetensors.torch.load_file(checkpoint_path_2)
+ if (checkpoint_path_2.endswith(".safetensors"))
+ else torch.load(checkpoint_path_2, map_location="cpu")
+ )
+
+ if not theta_0.keys() == theta_1.keys():
+ print(f"Skipping {attr}: key mismatch")
+ continue
+ if theta_2 and not theta_1.keys() == theta_2.keys():
+                        print(f"Skipping {attr}: key mismatch")
+                        continue
+ except Exception as e:
+                print(f"Skipping {attr} due to an unexpected error: {str(e)}")
+ continue
+ print(f"MERGING {attr}")
+
+ for key in theta_0.keys():
+ if theta_2:
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
+ else:
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)
+
+ del theta_1
+ del theta_2
+ update_theta_0(theta_0)
+
+ del theta_0
+ return final_pipe
+
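+    # Weighted-sum merge: linear interpolation (1 - alpha) * theta0 + alpha * theta1 (theta2 is unused).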
+ @staticmethod
+ def weighted_sum(theta0, theta1, theta2, alpha):
+ return ((1 - alpha) * theta0) + (alpha * theta1)
+
+ # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+ @staticmethod
+ def sigmoid(theta0, theta1, theta2, alpha):
+ alpha = alpha * alpha * (3 - (2 * alpha))
+ return theta0 + ((theta1 - theta0) * alpha)
+
+ # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+ @staticmethod
+ def inv_sigmoid(theta0, theta1, theta2, alpha):
+ import math
+
+ alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
+ return theta0 + ((theta1 - theta0) * alpha)
+
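+    # Add-difference merge: add the delta between the second and third checkpoints to the first, scaled by (1 - alpha).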
+ @staticmethod
+ def add_difference(theta0, theta1, theta2, alpha):
+ return theta0 + (theta1 - theta2) * (1.0 - alpha)
diff --git a/diffusers/examples/community/clip_guided_images_mixing_stable_diffusion.py b/diffusers/examples/community/clip_guided_images_mixing_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a4b12ad20fa59ecc690cbc71be7007667edc71
--- /dev/null
+++ b/diffusers/examples/community/clip_guided_images_mixing_stable_diffusion.py
@@ -0,0 +1,445 @@
+# -*- coding: utf-8 -*-
+import inspect
+from typing import Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from torch.nn import functional as F
+from torchvision import transforms
+from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import PIL_INTERPOLATION
+from diffusers.utils.torch_utils import randn_tensor
+
+
+def preprocess(image, w, h):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
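+# Spherical linear interpolation between two vectors (numpy arrays or torch tensors); falls back to
+# plain linear interpolation when the inputs are nearly parallel.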
+def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+    inputs_are_torch = False
+    if not isinstance(v0, np.ndarray):
+        inputs_are_torch = True
+        input_device = v0.device
+        v0 = v0.cpu().numpy()
+        v1 = v1.cpu().numpy()
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ v2 = (1 - t) * v0 + t * v1
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0 + s1 * v1
+
+ if inputs_are_torch:
+ v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
+
+
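+# Spherical distance loss: proportional to the squared great-circle distance between L2-normalized embeddings (used for CLIP guidance).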
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+def set_requires_grad(model, value):
+ for param in model.parameters():
+ param.requires_grad = value
+
+
+class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ clip_model: CLIPModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
+ feature_extractor: CLIPImageProcessor,
+ coca_model=None,
+ coca_tokenizer=None,
+ coca_transform=None,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ clip_model=clip_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ coca_model=coca_model,
+ coca_tokenizer=coca_tokenizer,
+ coca_transform=coca_transform,
+ )
+ self.feature_extractor_size = (
+ feature_extractor.size
+ if isinstance(feature_extractor.size, int)
+ else feature_extractor.size["shortest_edge"]
+ )
+ self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
+ set_requires_grad(self.text_encoder, False)
+ set_requires_grad(self.clip_model, False)
+
+ def freeze_vae(self):
+ set_requires_grad(self.vae, False)
+
+ def unfreeze_vae(self):
+ set_requires_grad(self.vae, True)
+
+ def freeze_unet(self):
+ set_requires_grad(self.unet, False)
+
+ def unfreeze_unet(self):
+ set_requires_grad(self.unet, True)
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None):
+ if not isinstance(image, torch.Tensor):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` but is {type(image)}")
+
+ image = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        # Hardcode 0.18215 because stable-diffusion-2-base does not have self.vae.config.scaling_factor
+ init_latents = 0.18215 * init_latents
+ init_latents = init_latents.repeat_interleave(batch_size, dim=0)
+
+ noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ def get_image_description(self, image):
+ transformed_image = self.coca_transform(image).unsqueeze(0)
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ generated = self.coca_model.generate(transformed_image.to(device=self.device, dtype=self.coca_model.dtype))
+ generated = self.coca_tokenizer.decode(generated[0].cpu().numpy())
+            return generated.split("<end_of_text>")[0].replace("<start_of_text>", "").rstrip(" .,")
+
+ def get_clip_image_embeddings(self, image, batch_size):
+ clip_image_input = self.feature_extractor.preprocess(image)
+ clip_image_features = torch.from_numpy(clip_image_input["pixel_values"][0]).unsqueeze(0).to(self.device).half()
+ image_embeddings_clip = self.clip_model.get_image_features(clip_image_features)
+ image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+ image_embeddings_clip = image_embeddings_clip.repeat_interleave(batch_size, dim=0)
+ return image_embeddings_clip
+
+ @torch.enable_grad()
+ def cond_fn(
+ self,
+ latents,
+ timestep,
+ index,
+ text_embeddings,
+ noise_pred_original,
+ original_image_embeddings_clip,
+ clip_guidance_scale,
+ ):
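+        # CLIP-guidance step: estimate the denoised image (predicted x_0) from the
+        # current latents, decode it with the VAE, embed it with CLIP, and shift
+        # the noise prediction along the gradient that pulls this embedding toward
+        # `original_image_embeddings_clip`.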
+ latents = latents.detach().requires_grad_()
+
+ latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
+
+ if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ # compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+ fac = torch.sqrt(beta_prod_t)
+ sample = pred_original_sample * (fac) + latents * (1 - fac)
+ elif isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[index]
+ sample = latents - sigma * noise_pred
+ else:
+ raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
+
+        # Hardcode 0.18215 because stable-diffusion-2-base does not have self.vae.config.scaling_factor
+ sample = 1 / 0.18215 * sample
+ image = self.vae.decode(sample).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ image = transforms.Resize(self.feature_extractor_size)(image)
+ image = self.normalize(image).to(latents.dtype)
+
+ image_embeddings_clip = self.clip_model.get_image_features(image)
+ image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+
+ loss = spherical_dist_loss(image_embeddings_clip, original_image_embeddings_clip).mean() * clip_guidance_scale
+
+ grads = -torch.autograd.grad(loss, latents)[0]
+
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents.detach() + grads * (sigma**2)
+ noise_pred = noise_pred_original
+ else:
+ noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
+ return noise_pred, latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ style_image: Union[torch.Tensor, PIL.Image.Image],
+ content_image: Union[torch.Tensor, PIL.Image.Image],
+ style_prompt: Optional[str] = None,
+ content_prompt: Optional[str] = None,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ noise_strength: float = 0.6,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ batch_size: Optional[int] = 1,
+ eta: float = 0.0,
+ clip_guidance_scale: Optional[float] = 100,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ slerp_latent_style_strength: float = 0.8,
+ slerp_prompt_style_strength: float = 0.1,
+ slerp_clip_image_style_strength: float = 0.1,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(f"You have passed batch_size={batch_size}, but only {len(generator)} generators.")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if isinstance(generator, torch.Generator) and batch_size > 1:
+ generator = [generator] + [None] * (batch_size - 1)
+
+ coca_is_none = [
+ ("model", self.coca_model is None),
+ ("tokenizer", self.coca_tokenizer is None),
+ ("transform", self.coca_transform is None),
+ ]
+ coca_is_none = [x[0] for x in coca_is_none if x[1]]
+ coca_is_none_str = ", ".join(coca_is_none)
+ # generate prompts with coca model if prompt is None
+ if content_prompt is None:
+ if len(coca_is_none):
+ raise ValueError(
+                    f"Content prompt is None and CoCa [{coca_is_none_str}] is None."
+                    f" Set prompt or pass CoCa [{coca_is_none_str}] to DiffusionPipeline."
+ )
+ content_prompt = self.get_image_description(content_image)
+ if style_prompt is None:
+ if len(coca_is_none):
+ raise ValueError(
+ f"Style prompt is None and CoCa [{coca_is_none_str}] is None."
+                    f" Set prompt or pass CoCa [{coca_is_none_str}] to DiffusionPipeline."
+ )
+ style_prompt = self.get_image_description(style_image)
+
+ # get prompt text embeddings for content and style
+ content_text_input = self.tokenizer(
+ content_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ content_text_embeddings = self.text_encoder(content_text_input.input_ids.to(self.device))[0]
+
+ style_text_input = self.tokenizer(
+ style_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ style_text_embeddings = self.text_encoder(style_text_input.input_ids.to(self.device))[0]
+
+ text_embeddings = slerp(slerp_prompt_style_strength, content_text_embeddings, style_text_embeddings)
+
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(batch_size, dim=0)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+ # Some schedulers like PNDM have timesteps as arrays
+        # It's more efficient to move all timesteps to the correct device beforehand
+        self.scheduler.timesteps = self.scheduler.timesteps.to(self.device)
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, noise_strength, self.device)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+
+ # Preprocess image
+ preprocessed_content_image = preprocess(content_image, width, height)
+ content_latents = self.prepare_latents(
+ preprocessed_content_image, latent_timestep, batch_size, text_embeddings.dtype, self.device, generator
+ )
+
+ preprocessed_style_image = preprocess(style_image, width, height)
+ style_latents = self.prepare_latents(
+ preprocessed_style_image, latent_timestep, batch_size, text_embeddings.dtype, self.device, generator
+ )
+
+ latents = slerp(slerp_latent_style_strength, content_latents, style_latents)
+
+ if clip_guidance_scale > 0:
+ content_clip_image_embedding = self.get_clip_image_embeddings(content_image, batch_size)
+ style_clip_image_embedding = self.get_clip_image_embeddings(style_image, batch_size)
+ clip_image_embeddings = slerp(
+ slerp_clip_image_style_strength, content_clip_image_embedding, style_clip_image_embedding
+ )
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = content_text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size, dim=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform classifier free guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # perform clip guidance
+ if clip_guidance_scale > 0:
+ text_embeddings_for_guidance = (
+ text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
+ )
+ noise_pred, latents = self.cond_fn(
+ latents,
+ t,
+ i,
+ text_embeddings_for_guidance,
+ noise_pred,
+ clip_image_embeddings,
+ clip_guidance_scale,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ progress_bar.update()
+        # Hardcode 0.18215 because stable-diffusion-2-base does not have self.vae.config.scaling_factor
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, None)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
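+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the pipeline class itself. The model ids,
+    # image paths, prompts, and the "clip_guided_images_mixing_stable_diffusion"
+    # custom-pipeline name are illustrative assumptions; passing explicit prompts
+    # skips CoCa captioning, so the coca_* components can stay at their None defaults.
+    clip_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+    feature_extractor = CLIPImageProcessor.from_pretrained(clip_id)
+    clip_model = CLIPModel.from_pretrained(clip_id, torch_dtype=torch.float16)
+
+    pipe = DiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4",
+        custom_pipeline="clip_guided_images_mixing_stable_diffusion",
+        clip_model=clip_model,
+        feature_extractor=feature_extractor,
+        torch_dtype=torch.float16,
+    ).to("cuda")
+
+    content = PIL.Image.open("content.png").convert("RGB")  # hypothetical path
+    style = PIL.Image.open("style.png").convert("RGB")  # hypothetical path
+    result = pipe(
+        content_image=content,
+        style_image=style,
+        content_prompt="a photo of a city skyline",
+        style_prompt="an impressionist oil painting",
+        noise_strength=0.6,
+        num_inference_steps=50,
+        clip_guidance_scale=100,
+        generator=torch.Generator(device="cuda").manual_seed(0),
+    ).images[0]
+    result.save("mixed.png")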
diff --git a/diffusers/examples/community/clip_guided_stable_diffusion.py b/diffusers/examples/community/clip_guided_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..1350650113dada84073f16d8cabc6a63eb9d0165
--- /dev/null
+++ b/diffusers/examples/community/clip_guided_stable_diffusion.py
@@ -0,0 +1,337 @@
+import inspect
+from typing import List, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+
+
+class MakeCutouts(nn.Module):
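+    # Samples `num_cutouts` random square crops (sizes biased by `cut_power`) and
+    # pools each to `cut_size`, so the CLIP loss sees several views of the decoded
+    # image instead of a single downscaled one.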
+ def __init__(self, cut_size, cut_power=1.0):
+ super().__init__()
+
+ self.cut_size = cut_size
+ self.cut_power = cut_power
+
+ def forward(self, pixel_values, num_cutouts):
+ sideY, sideX = pixel_values.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ cutouts = []
+ for _ in range(num_cutouts):
+ size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
+ cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+ return torch.cat(cutouts)
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+def set_requires_grad(model, value):
+ for param in model.parameters():
+ param.requires_grad = value
+
+
+class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
+ """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
+ - https://github.com/Jack000/glid-3-xl
+ - https://github.dev/crowsonkb/k-diffusion
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ clip_model: CLIPModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ clip_model=clip_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+
+ self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
+ self.cut_out_size = (
+ feature_extractor.size
+ if isinstance(feature_extractor.size, int)
+ else feature_extractor.size["shortest_edge"]
+ )
+ self.make_cutouts = MakeCutouts(self.cut_out_size)
+
+ set_requires_grad(self.text_encoder, False)
+ set_requires_grad(self.clip_model, False)
+
+ def freeze_vae(self):
+ set_requires_grad(self.vae, False)
+
+ def unfreeze_vae(self):
+ set_requires_grad(self.vae, True)
+
+ def freeze_unet(self):
+ set_requires_grad(self.unet, False)
+
+ def unfreeze_unet(self):
+ set_requires_grad(self.unet, True)
+
+ @torch.enable_grad()
+ def cond_fn(
+ self,
+ latents,
+ timestep,
+ index,
+ text_embeddings,
+ noise_pred_original,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts=True,
+ ):
+ latents = latents.detach().requires_grad_()
+
+ latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
+
+ if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ # compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+ fac = torch.sqrt(beta_prod_t)
+ sample = pred_original_sample * (fac) + latents * (1 - fac)
+ elif isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[index]
+ sample = latents - sigma * noise_pred
+ else:
+ raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
+
+ sample = 1 / self.vae.config.scaling_factor * sample
+ image = self.vae.decode(sample).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ if use_cutouts:
+ image = self.make_cutouts(image, num_cutouts)
+ else:
+ image = transforms.Resize(self.cut_out_size)(image)
+ image = self.normalize(image).to(latents.dtype)
+
+ image_embeddings_clip = self.clip_model.get_image_features(image)
+ image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+
+ if use_cutouts:
+ dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
+ dists = dists.view([num_cutouts, sample.shape[0], -1])
+ loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
+ else:
+ loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
+
+ grads = -torch.autograd.grad(loss, latents)[0]
+
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents.detach() + grads * (sigma**2)
+ noise_pred = noise_pred_original
+ else:
+ noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
+ return noise_pred, latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ clip_guidance_scale: Optional[float] = 100,
+ clip_prompt: Optional[Union[str, List[str]]] = None,
+ num_cutouts: Optional[int] = 4,
+ use_cutouts: Optional[bool] = True,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if clip_guidance_scale > 0:
+ if clip_prompt is not None:
+ clip_text_input = self.tokenizer(
+ clip_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids.to(self.device)
+ else:
+ clip_text_input = text_input.input_ids.to(self.device)
+ text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
+ text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+ # duplicate text embeddings clip for each generation per prompt
+ text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ # Some schedulers like PNDM have timesteps as arrays
+        # It's more efficient to move all timesteps to the correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform classifier free guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # perform clip guidance
+ if clip_guidance_scale > 0:
+ text_embeddings_for_guidance = (
+ text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
+ )
+ noise_pred, latents = self.cond_fn(
+ latents,
+ t,
+ i,
+ text_embeddings_for_guidance,
+ noise_pred,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, None)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
diff --git a/diffusers/examples/community/clip_guided_stable_diffusion_img2img.py b/diffusers/examples/community/clip_guided_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..91c74b9ffa74ad415cf75b8c53506afed6ccca5f
--- /dev/null
+++ b/diffusers/examples/community/clip_guided_stable_diffusion_img2img.py
@@ -0,0 +1,490 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import PIL_INTERPOLATION, deprecate
+from diffusers.utils.torch_utils import randn_tensor
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ from io import BytesIO
+
+ import requests
+ import torch
+ from diffusers import DiffusionPipeline
+ from PIL import Image
+ from transformers import CLIPImageProcessor, CLIPModel
+
+ feature_extractor = CLIPImageProcessor.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+ )
+ clip_model = CLIPModel.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
+ )
+
+
+ guided_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+        custom_pipeline="clip_guided_stable_diffusion_img2img",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ torch_dtype=torch.float16,
+ )
+ guided_pipeline.enable_attention_slicing()
+ guided_pipeline = guided_pipeline.to("cuda")
+
+ prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
+
+ url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+ response = requests.get(url)
+ init_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+ image = guided_pipeline(
+ prompt=prompt,
+ num_inference_steps=30,
+ image=init_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ clip_guidance_scale=100,
+ num_cutouts=4,
+ use_cutouts=False,
+ ).images[0]
+ display(image)
+ ```
+"""
+
+
+def preprocess(image, w, h):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+class MakeCutouts(nn.Module):
+ def __init__(self, cut_size, cut_power=1.0):
+ super().__init__()
+
+ self.cut_size = cut_size
+ self.cut_power = cut_power
+
+ def forward(self, pixel_values, num_cutouts):
+ sideY, sideX = pixel_values.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ cutouts = []
+ for _ in range(num_cutouts):
+ size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
+ cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+ return torch.cat(cutouts)
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+def set_requires_grad(model, value):
+ for param in model.parameters():
+ param.requires_grad = value
+
+
+class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
+ """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
+ - https://github.com/Jack000/glid-3-xl
+ - https://github.dev/crowsonkb/k-diffusion
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ clip_model: CLIPModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ clip_model=clip_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+
+ self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
+ self.cut_out_size = (
+ feature_extractor.size
+ if isinstance(feature_extractor.size, int)
+ else feature_extractor.size["shortest_edge"]
+ )
+ self.make_cutouts = MakeCutouts(self.cut_out_size)
+
+ set_requires_grad(self.text_encoder, False)
+ set_requires_grad(self.clip_model, False)
+
+ def freeze_vae(self):
+ set_requires_grad(self.vae, False)
+
+ def unfreeze_vae(self):
+ set_requires_grad(self.vae, True)
+
+ def freeze_unet(self):
+ set_requires_grad(self.unet, False)
+
+ def unfreeze_unet(self):
+ set_requires_grad(self.unet, True)
+
+ def get_timesteps(self, num_inference_steps, strength, device):
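+        # `strength` in (0, 1] selects how much of the denoising schedule to run:
+        # strength=1.0 keeps all `num_inference_steps`, while smaller values skip
+        # the earliest (noisiest) steps so the result stays closer to the init image.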
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+                " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
+                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.enable_grad()
+ def cond_fn(
+ self,
+ latents,
+ timestep,
+ index,
+ text_embeddings,
+ noise_pred_original,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts=True,
+ ):
+ latents = latents.detach().requires_grad_()
+
+ latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
+
+ if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ # compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+ fac = torch.sqrt(beta_prod_t)
+ sample = pred_original_sample * (fac) + latents * (1 - fac)
+ elif isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[index]
+ sample = latents - sigma * noise_pred
+ else:
+ raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
+
+ sample = 1 / self.vae.config.scaling_factor * sample
+ image = self.vae.decode(sample).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ if use_cutouts:
+ image = self.make_cutouts(image, num_cutouts)
+ else:
+ image = transforms.Resize(self.cut_out_size)(image)
+ image = self.normalize(image).to(latents.dtype)
+
+ image_embeddings_clip = self.clip_model.get_image_features(image)
+ image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+
+ if use_cutouts:
+ dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
+ dists = dists.view([num_cutouts, sample.shape[0], -1])
+ loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
+ else:
+ loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
+
+ grads = -torch.autograd.grad(loss, latents)[0]
+
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents.detach() + grads * (sigma**2)
+ noise_pred = noise_pred_original
+ else:
+ noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
+ return noise_pred, latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ clip_guidance_scale: Optional[float] = 100,
+ clip_prompt: Optional[Union[str, List[str]]] = None,
+ num_cutouts: Optional[int] = 4,
+ use_cutouts: Optional[bool] = True,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+ # Some schedulers like PNDM have timesteps as arrays
+        # It's more efficient to move all timesteps to the correct device beforehand
+        self.scheduler.timesteps = self.scheduler.timesteps.to(self.device)
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # Preprocess image
+ image = preprocess(image, width, height)
+ if latents is None:
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ text_embeddings.dtype,
+ self.device,
+ generator,
+ )
+
+ if clip_guidance_scale > 0:
+ if clip_prompt is not None:
+ clip_text_input = self.tokenizer(
+ clip_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids.to(self.device)
+ else:
+ clip_text_input = text_input.input_ids.to(self.device)
+ text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
+ text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+ # duplicate text embeddings clip for each generation per prompt
+ text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform classifier free guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # perform clip guidance
+ if clip_guidance_scale > 0:
+ text_embeddings_for_guidance = (
+ text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
+ )
+ noise_pred, latents = self.cond_fn(
+ latents,
+ t,
+ i,
+ text_embeddings_for_guidance,
+ noise_pred,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                progress_bar.update()
+
+ # scale and decode the image latents with vae
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, None)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
diff --git a/diffusers/examples/community/composable_stable_diffusion.py b/diffusers/examples/community/composable_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..46d12ba1f2aa4222c99523917eec1274f50a41a0
--- /dev/null
+++ b/diffusers/examples/community/composable_stable_diffusion.py
@@ -0,0 +1,532 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ weights: Optional[str] = "",
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
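+ weights (`str`, *optional*, defaults to `""`):
+ A `|`-separated string of per-prompt guidance weights, used when `prompt` itself contains
+ `|`-separated sub-prompts (composable diffusion), e.g. `prompt="a forest | a castle"` with
+ `weights="7.5 | 7.5"`. If empty, every sub-prompt uses `guidance_scale` as its weight.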
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if "|" in prompt:
+ prompt = [x.strip() for x in prompt.split("|")]
+ print(f"composing {prompt}...")
+
+ if not weights:
+ # specify weights for prompts (excluding the unconditional score)
+ print("using equal positive weights (conjunction) for all prompts...")
+ weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)
+ else:
+ # set prompt weight for each
+ num_prompts = len(prompt) if isinstance(prompt, list) else 1
+ weights = [float(w.strip()) for w in weights.split("|")]
+ # guidance scale as the default
+ if len(weights) < num_prompts:
+ weights.append(guidance_scale)
+ else:
+ weights = weights[:num_prompts]
+ assert len(weights) == len(prompt), "weights specified are not equal to the number of prompts"
+ weights = torch.tensor(weights, device=self.device).reshape(-1, 1, 1, 1)
+ else:
+ weights = guidance_scale
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # composable diffusion
+ if isinstance(prompt, list) and batch_size == 1:
+ # remove extra unconditional embedding
+ # N = one unconditional embed + conditional embeds
+ text_embeddings = text_embeddings[len(prompt) - 1 :]
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = []
+ for j in range(text_embeddings.shape[0]):
+ noise_pred.append(
+ self.unet(latent_model_input[:1], t, encoder_hidden_states=text_embeddings[j : j + 1]).sample
+ )
+ noise_pred = torch.cat(noise_pred, dim=0)
+
+ # perform guidance
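+ # composable diffusion: eps = eps_uncond + sum_i w_i * (eps_i - eps_uncond); with a single
+ # prompt and w equal to `guidance_scale` this reduces to standard classifier-free guidance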
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred[:1], noise_pred[1:]
+ noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(
+ dim=0, keepdims=True
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/ddim_noise_comparative_analysis.py b/diffusers/examples/community/ddim_noise_comparative_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..829106c47f65cb70746c031e9d4f9ef698035afa
--- /dev/null
+++ b/diffusers/examples/community/ddim_noise_comparative_analysis.py
@@ -0,0 +1,190 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import PIL.Image
+import torch
+from torchvision import transforms
+
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils.torch_utils import randn_tensor
+
+
+trans = transforms.Compose(
+ [
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+)
+
+
+def preprocess(image):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ image = [trans(img.convert("RGB")) for img in image]
+ image = torch.stack(image)
+ return image
+
+
+class DDIMNoiseComparativeAnalysisPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ # make sure scheduler can always be converted to DDIM
+ scheduler = DDIMScheduler.from_config(scheduler.config)
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ def check_inputs(self, strength):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ init_latents = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ print("add noise to latents at timestep", timestep)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ strength: float = 0.8,
+ batch_size: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ use_clipped_model_output: Optional[bool] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ use_clipped_model_output (`bool`, *optional*, defaults to `None`):
+ if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed
+ downstream to the scheduler. So use `None` for schedulers which don't support this argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
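+
+ Example (a minimal sketch; the checkpoint id, file path and `init_image` are illustrative assumptions, and
+ the constructor converts whatever scheduler is passed to DDIM):
+
+ ```py
+ >>> from PIL import Image
+ >>> from diffusers import DDPMScheduler, UNet2DModel
+
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
+ >>> unet = UNet2DModel.from_pretrained("google/ddpm-celebahq-256")
+ >>> pipe = DDIMNoiseComparativeAnalysisPipeline(unet=unet, scheduler=scheduler)
+
+ >>> init_image = Image.open("input.png")
+ >>> images, noise_timestep = pipe(image=init_image, strength=0.5, return_dict=False)
+ ```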
+ """
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(strength)
+
+ # 2. Preprocess image
+ image = preprocess(image)
+
+ # 3. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+
+ # 4. Prepare latent variables
+ latents = self.prepare_latents(image, latent_timestep, batch_size, self.unet.dtype, self.device, generator)
+ image = latents
+
+ # 5. Denoising loop
+ for t in self.progress_bar(timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ image = self.scheduler.step(
+ model_output,
+ t,
+ image,
+ eta=eta,
+ use_clipped_model_output=use_clipped_model_output,
+ generator=generator,
+ ).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, latent_timestep.item())
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/examples/community/dps_pipeline.py b/diffusers/examples/community/dps_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0bf3e0ad33d0e09e2f1d30ad8a091cd17a2aad9
--- /dev/null
+++ b/diffusers/examples/community/dps_pipeline.py
@@ -0,0 +1,466 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from math import pi
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from diffusers import DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DModel
+from diffusers.utils.torch_utils import randn_tensor
+
+
+class DPSPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for Diffusion Posterior Sampling.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Parameters:
+ unet ([`UNet2DModel`]):
+ A `UNet2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ model_cpu_offload_seq = "unet"
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ measurement: torch.Tensor,
+ operator: torch.nn.Module,
+ loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+ batch_size: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ num_inference_steps: int = 1000,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ zeta: float = 0.3,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ measurement (`torch.Tensor`, *required*):
+ A `torch.Tensor` holding the corrupted image (the measurement `y`).
+ operator (`torch.nn.Module`, *required*):
+ A `torch.nn.Module` implementing the forward operator `f` that produced the corrupted image.
+ loss_fn (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *required*):
+ The loss function computed between the measurement and its prediction; in most cases using RMSE
+ works well.
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 1000):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Example:
+
+ ```py
+ >>> from diffusers import DDPMScheduler, UNet2DModel
+
+ >>> # load model and scheduler
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
+ >>> model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
+ >>> pipe = DPSPipeline(model, scheduler)
+
+ >>> # run pipeline with a corrupted measurement, the operator that produced it and a loss function
+ >>> # (see the `__main__` block at the bottom of this file for a full setup of these objects)
+ >>> image = pipe(measurement=measurement, operator=operator, loss_fn=RMSELoss, zeta=1.0).images[0]
+
+ >>> # save image
+ >>> image.save("dps_generated_image.png")
+ ```
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ # Sample gaussian noise to begin loop
+ if isinstance(self.unet.config.sample_size, int):
+ image_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ self.unet.config.sample_size,
+ self.unet.config.sample_size,
+ )
+ else:
+ image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
+
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ image = randn_tensor(image_shape, generator=generator)
+ image = image.to(self.device)
+ else:
+ image = randn_tensor(image_shape, generator=generator, device=self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ with torch.enable_grad():
+ # 1. predict noise model_output
+ image = image.requires_grad_()
+ model_output = self.unet(image, t).sample
+
+ # 2. compute previous image x'_{t-1} and original prediction x0_{t}
+ scheduler_out = self.scheduler.step(model_output, t, image, generator=generator)
+ image_pred, origi_pred = scheduler_out.prev_sample, scheduler_out.pred_original_sample
+
+ # 3. compute y'_t = f(x0_{t})
+ measurement_pred = operator(origi_pred)
+
+ # 4. compute loss = d(y, y'_t-1)
+ loss = loss_fn(measurement, measurement_pred)
+ loss.backward()
+
+ print("distance: {0:.4f}".format(loss.item()))
+
+ with torch.no_grad():
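+ # 5. posterior-sampling update: x_{t-1} = x'_{t-1} - zeta * grad_x loss(y, f(x0_hat)),
+ # i.e. shift the scheduler's proposal by the gradient of the measurement loss w.r.t. the noisy image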
+ image_pred = image_pred - zeta * image.grad
+ image = image_pred.detach()
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
+
+
+if __name__ == "__main__":
+ import scipy
+ from torch import nn
+ from torchvision.utils import save_image
+
+ # defining the operators f(.) of y = f(x)
+ # super-resolution operator
+ class SuperResolutionOperator(nn.Module):
+ def __init__(self, in_shape, scale_factor):
+ super().__init__()
+
+ # Resizer local class, do not use outside the SR operator class
+ class Resizer(nn.Module):
+ def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):
+ super(Resizer, self).__init__()
+
+ # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa
+ scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)
+
+ # Choose interpolation method, each method has the matching kernel size
+ def cubic(x):
+ absx = np.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) + (
+ -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
+ ) * ((1 < absx) & (absx <= 2))
+
+ def lanczos2(x):
+ return (
+ (np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps)
+ / ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)
+ ) * (abs(x) < 2)
+
+ def box(x):
+ return ((-0.5 <= x) & (x < 0.5)) * 1.0
+
+ def lanczos3(x):
+ return (
+ (np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps)
+ / ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)
+ ) * (abs(x) < 3)
+
+ def linear(x):
+ return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
+
+ method, kernel_width = {
+ "cubic": (cubic, 4.0),
+ "lanczos2": (lanczos2, 4.0),
+ "lanczos3": (lanczos3, 6.0),
+ "box": (box, 1.0),
+ "linear": (linear, 2.0),
+ None: (cubic, 4.0), # set default interpolation method as cubic
+ }.get(kernel)
+
+ # Antialiasing is only used when downscaling
+ antialiasing *= np.any(np.array(scale_factor) < 1)
+
+ # Sort indices of dimensions according to the scale of each dimension. Since we go dim by dim, this is efficient
+ sorted_dims = np.argsort(np.array(scale_factor))
+ self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]
+
+ # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction
+ field_of_view_list = []
+ weights_list = []
+ for dim in self.sorted_dims:
+ # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the
+ # weights that multiply the values there to get its result.
+ weights, field_of_view = self.contributions(
+ in_shape[dim], output_shape[dim], scale_factor[dim], method, kernel_width, antialiasing
+ )
+
+ # convert to torch tensor
+ weights = torch.tensor(weights.T, dtype=torch.float32)
+
+ # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
+ # tmp_im[field_of_view.T], (bsxfun style)
+ weights_list.append(
+ nn.Parameter(
+ torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),
+ requires_grad=False,
+ )
+ )
+ field_of_view_list.append(
+ nn.Parameter(
+ torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False
+ )
+ )
+
+ self.field_of_view = nn.ParameterList(field_of_view_list)
+ self.weights = nn.ParameterList(weights_list)
+
+ def forward(self, in_tensor):
+ x = in_tensor
+
+ # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim
+ for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):
+ # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
+ x = torch.transpose(x, dim, 0)
+
+ # This is a bit of a complicated multiplication: x[field_of_view.T] is a tensor of order image_dims+1.
+ # For each pixel in the output image it gathers the input positions that influence it (along 1 dim
+ # only, which is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of
+ # positions with the matching set of weights, via a big element-wise tensor multiplication (MATLAB
+ # bsxfun style: matching dims are multiplied element-wise while singleton dims broadcast the same
+ # number across the whole matching dim).
+ x = torch.sum(x[fov] * w, dim=0)
+
+ # Finally we swap back the axes to the original order
+ x = torch.transpose(x, dim, 0)
+
+ return x
+
+ def fix_scale_and_size(self, input_shape, output_shape, scale_factor):
+ # First fix the scale-factor (if given) to the standardized form the function expects (a list of scale
+ # factors, one per input dimension)
+ if scale_factor is not None:
+ # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.
+ if np.isscalar(scale_factor) and len(input_shape) > 1:
+ scale_factor = [scale_factor, scale_factor]
+
+ # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales
+ scale_factor = list(scale_factor)
+ scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor
+
+ # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size
+ # to all the unspecified dimensions
+ if output_shape is not None:
+ output_shape = list(input_shape[len(output_shape) :]) + list(np.uint(np.array(output_shape)))
+
+ # Dealing with the case of a non-given scale-factor, calculating according to the output-shape. Note that this is
+ # sub-optimal, because different scale-factors can lead to the same output-shape.
+ if scale_factor is None:
+ scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)
+
+ # Dealing with missing output-shape. calculating according to scale-factor
+ if output_shape is None:
+ output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))
+
+ return scale_factor, output_shape
+
+ def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
+ # such that each position from the field_of_view will be multiplied with a matching filter from the
+ # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers
+ # around it. This is only done for one dimension of the image.
+
+ # When anti-aliasing is activated (the default, and only for downscaling) the receptive field is stretched to a
+ # size of 1/sf, which makes the filtering more low-pass.
+ fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
+ kernel_width *= 1.0 / scale if antialiasing else 1.0
+
+ # These are the coordinates of the output image
+ out_coordinates = np.arange(1, out_length + 1)
+
+ # Since both scale-factor and output size can be provided simultaneously, preserving the center of the image requires shifting
+ # the output coordinates. The deviation is because out_length doesn't necessarily equal in_length*scale.
+ # To keep the center we need to subtract half of this deviation so that we get equal margins on both sides and the center is preserved.
+ shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2
+
+ # These are the matching positions of the output-coordinates on the input image coordinates.
+ # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
+ # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
+ # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to
+ # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big
+ # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor).
+ # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is
+ # at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means:
+ # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf)
+ match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)
+
+ # This is the left boundary to start multiplying the filter from, it depends on the size of the filter
+ left_boundary = np.floor(match_coordinates - kernel_width / 2)
+
+ # Kernel width needs to be enlarged because when the coverage has sub-pixel borders, it must 'see' the centers
+ # of pixels it only partially covers. So we add one pixel at each side (their weights can be zero).
+ expanded_kernel_width = np.ceil(kernel_width) + 2
+
+ # Determine a set of field_of_view for each output position; these are the pixels in the input image
+ # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big) and the
+ # vertical dim is the pixels it 'sees' (kernel_size + 2)
+ field_of_view = np.squeeze(
+ np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)
+ )
+
+ # Assign weight to each pixel in the field of view. A matrix whose horizontal dim is the output pixels and the
+ # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in
+ # 'field_of_view')
+ weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
+
+ # Normalize weights to sum up to 1, being careful not to divide by 0
+ sum_weights = np.sum(weights, axis=1)
+ sum_weights[sum_weights == 0] = 1.0
+ weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
+
+ # We use this mirror structure as a trick for reflection padding at the boundaries
+ mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
+ field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
+
+ # Get rid of weights and pixel positions that are of zero weight
+ non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
+ weights = np.squeeze(weights[:, non_zero_out_pixels])
+ field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
+
+ # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size
+ return weights, field_of_view
+
+ self.down_sample = Resizer(in_shape, 1 / scale_factor)
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, data, **kwargs):
+ return self.down_sample(data)
+
+ # Gaussian blurring operator
+ class GaussialBlurOperator(nn.Module):
+ def __init__(self, kernel_size, intensity):
+ super().__init__()
+
+ class Blurkernel(nn.Module):
+ def __init__(self, blur_type="gaussian", kernel_size=31, std=3.0):
+ super().__init__()
+ self.blur_type = blur_type
+ self.kernel_size = kernel_size
+ self.std = std
+ self.seq = nn.Sequential(
+ nn.ReflectionPad2d(self.kernel_size // 2),
+ nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
+ )
+ self.weights_init()
+
+ def forward(self, x):
+ return self.seq(x)
+
+ def weights_init(self):
+ if self.blur_type == "gaussian":
+ n = np.zeros((self.kernel_size, self.kernel_size))
+ n[self.kernel_size // 2, self.kernel_size // 2] = 1
+ k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
+ k = torch.from_numpy(k)
+ self.k = k
+ for name, f in self.named_parameters():
+ f.data.copy_(k)
+
+ def update_weights(self, k):
+ if not torch.is_tensor(k):
+ k = torch.from_numpy(k)
+ for name, f in self.named_parameters():
+ f.data.copy_(k)
+
+ def get_kernel(self):
+ return self.k
+
+ self.kernel_size = kernel_size
+ self.conv = Blurkernel(blur_type="gaussian", kernel_size=kernel_size, std=intensity)
+ self.kernel = self.conv.get_kernel()
+ self.conv.update_weights(self.kernel.type(torch.float32))
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, data, **kwargs):
+ return self.conv(data)
+
+ def transpose(self, data, **kwargs):
+ return data
+
+ def get_kernel(self):
+ return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
+
+ # assuming the forward process y = f(x) is polluted by Gaussian noise, use l2 norm
+ def RMSELoss(yhat, y):
+ return torch.sqrt(torch.sum((yhat - y) ** 2))
+
+ # set up source image
+ src = Image.open("sample.png")
+ # read image into [1,3,H,W]
+ src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2, 0, 1)[None]
+ # normalize image to [-1,1]
+ src = (src / 127.5) - 1.0
+ src = src.to("cuda")
+
+ # set up operator and measurement
+ # operator = SuperResolutionOperator(in_shape=src.shape, scale_factor=4).to("cuda")
+ operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
+ measurement = operator(src)
+
+ # set up scheduler
+ scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
+ scheduler.set_timesteps(1000)
+
+ # set up model
+ model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
+
+ save_image((src + 1.0) / 2.0, "dps_src.png")
+ save_image((measurement + 1.0) / 2.0, "dps_mea.png")
+
+ # finally, the pipeline
+ dpspipe = DPSPipeline(model, scheduler)
+ image = dpspipe(
+ measurement=measurement,
+ operator=operator,
+ loss_fn=RMSELoss,
+ zeta=1.0,
+ ).images[0]
+
+ image.save("dps_generated_image.png")
diff --git a/diffusers/examples/community/edict_pipeline.py b/diffusers/examples/community/edict_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac977f79abecd281c07e780c76023216afb1a5f6
--- /dev/null
+++ b/diffusers/examples/community/edict_pipeline.py
@@ -0,0 +1,264 @@
+from typing import Optional
+
+import torch
+from PIL import Image
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.utils import (
+ deprecate,
+)
+
+
+class EDICTPipeline(DiffusionPipeline):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ mixing_coeff: float = 0.93,
+ leapfrog_steps: bool = True,
+ ):
+ self.mixing_coeff = mixing_coeff
+ self.leapfrog_steps = leapfrog_steps
+
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _encode_prompt(
+ self, prompt: str, negative_prompt: Optional[str] = None, do_classifier_free_guidance: bool = False
+ ):
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ prompt_embeds = self.text_encoder(text_inputs.input_ids.to(self.device)).last_hidden_state
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = "" if negative_prompt is None else negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device)).last_hidden_state
+
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
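+ # EDICT maintains two coupled latent sequences; the two mixing layers below are exact inverses of
+ # each other (an affine coupling), which is what makes the sampling process exactly invertible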
+ def denoise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor):
+ x = self.mixing_coeff * x + (1 - self.mixing_coeff) * y
+ y = self.mixing_coeff * y + (1 - self.mixing_coeff) * x
+
+ return [x, y]
+
+ def noise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor):
+ y = (y - (1 - self.mixing_coeff) * x) / self.mixing_coeff
+ x = (x - (1 - self.mixing_coeff) * y) / self.mixing_coeff
+
+ return [x, y]
+
+ def _get_alpha_and_beta(self, t: torch.Tensor):
+ # self.scheduler.alphas_cumprod always lives on the CPU, so index it with a plain Python int
+ t = int(t)
+
+ alpha_prod = self.scheduler.alphas_cumprod[t] if t >= 0 else self.scheduler.final_alpha_cumprod
+
+ return alpha_prod, 1 - alpha_prod
+
+ def noise_step(
+ self,
+ base: torch.Tensor,
+ model_input: torch.Tensor,
+ model_output: torch.Tensor,
+ timestep: torch.Tensor,
+ ):
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps
+
+ alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep)
+ alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep)
+
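+ # deterministic DDIM update in coefficient form: x_prev = a_t * x_t + b_t * eps, with
+ # a_t = sqrt(alpha_prev / alpha_t) and b_t = sqrt(1 - alpha_prev) - a_t * sqrt(1 - alpha_t);
+ # this noise (inversion) step solves that relation for x_t given x_prev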
+ a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
+ b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5
+
+ next_model_input = (base - b_t * model_output) / a_t
+
+ return model_input, next_model_input.to(base.dtype)
+
+ def denoise_step(
+ self,
+ base: torch.Tensor,
+ model_input: torch.Tensor,
+ model_output: torch.Tensor,
+ timestep: torch.Tensor,
+ ):
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps
+
+ alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep)
+ alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep)
+
+ a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
+ b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5
+ next_model_input = a_t * base + b_t * model_output
+
+ return model_input, next_model_input.to(base.dtype)
+
+ @torch.no_grad()
+ def decode_latents(self, latents: torch.Tensor):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ return image
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ image: Image.Image,
+ text_embeds: torch.Tensor,
+ timesteps: torch.Tensor,
+ guidance_scale: float,
+ generator: Optional[torch.Generator] = None,
+ ):
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ image = image.to(device=self.device, dtype=text_embeds.dtype)
+ latent = self.vae.encode(image).latent_dist.sample(generator)
+
+ latent = self.vae.config.scaling_factor * latent
+
+ coupled_latents = [latent.clone(), latent.clone()]
+
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
+ coupled_latents = self.noise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1])
+
+ # j - model_input index, k - base index
+ for j in range(2):
+ k = j ^ 1
+
+ if self.leapfrog_steps:
+ if i % 2 == 0:
+ k, j = j, k
+
+ model_input = coupled_latents[j]
+ base = coupled_latents[k]
+
+ latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
+
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ base, model_input = self.noise_step(
+ base=base,
+ model_input=model_input,
+ model_output=noise_pred,
+ timestep=t,
+ )
+
+ coupled_latents[k] = model_input
+
+ return coupled_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ base_prompt: str,
+ target_prompt: str,
+ image: Image.Image,
+ guidance_scale: float = 3.0,
+ num_inference_steps: int = 50,
+ strength: float = 0.8,
+ negative_prompt: Optional[str] = None,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ ):
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ image = self.image_processor.preprocess(image)
+
+ base_embeds = self._encode_prompt(base_prompt, negative_prompt, do_classifier_free_guidance)
+ target_embeds = self._encode_prompt(target_prompt, negative_prompt, do_classifier_free_guidance)
+
+ self.scheduler.set_timesteps(num_inference_steps, self.device)
+
+ t_limit = num_inference_steps - int(num_inference_steps * strength)
+ fwd_timesteps = self.scheduler.timesteps[t_limit:]
+ bwd_timesteps = fwd_timesteps.flip(0)
+
+ coupled_latents = self.prepare_latents(image, base_embeds, bwd_timesteps, guidance_scale, generator)
+
+ for i, t in tqdm(enumerate(fwd_timesteps), total=len(fwd_timesteps)):
+ # j - model_input index, k - base index
+ for k in range(2):
+ j = k ^ 1
+
+ if self.leapfrog_steps:
+ if i % 2 == 1:
+ k, j = j, k
+
+ model_input = coupled_latents[j]
+ base = coupled_latents[k]
+
+ latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
+
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=target_embeds).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ base, model_input = self.denoise_step(
+ base=base,
+ model_input=model_input,
+ model_output=noise_pred,
+ timestep=t,
+ )
+
+ coupled_latents[k] = model_input
+
+ coupled_latents = self.denoise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1])
+
+ # either one is fine
+ final_latent = coupled_latents[0]
+
+ if output_type not in ["latent", "pt", "np", "pil"]:
+ deprecation_message = (
+ f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
+ "`pil`, `np`, `pt`, `latent`"
+ )
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+ output_type = "np"
+
+ if output_type == "latent":
+ image = final_latent
+ else:
+ image = self.decode_latents(final_latent)
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ return image
diff --git a/diffusers/examples/community/fresco_v2v.py b/diffusers/examples/community/fresco_v2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab191ecf0d814217da7e204026ad4aed925abcd4
--- /dev/null
+++ b/diffusers/examples/community/fresco_v2v.py
@@ -0,0 +1,2511 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+import torch.utils.model_zoo
+from einops import rearrange, repeat
+from gmflow.gmflow import GMFlow
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import AttnProcessor2_0
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def clear_cache():
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+def coords_grid(b, h, w, homogeneous=False, device=None):
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
+
+ stacks = [x, y]
+
+ if homogeneous:
+ ones = torch.ones_like(x) # [H, W]
+ stacks.append(ones)
+
+ grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
+
+ grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
+
+ if device is not None:
+ grid = grid.to(device)
+
+ return grid
+
+
+def bilinear_sample(img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False):
+ # img: [B, C, H, W]
+ # sample_coords: [B, 2, H, W] in image scale
+ if sample_coords.size(1) != 2: # [B, H, W, 2]
+ sample_coords = sample_coords.permute(0, 3, 1, 2)
+
+ b, _, h, w = sample_coords.shape
+
+ # Normalize to [-1, 1]
+ x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
+ y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
+
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
+
+ img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
+
+ if return_mask:
+ mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
+
+ return img, mask
+
+ return img
+
+
+class Dilate:
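+ """Approximate morphological dilation of a soft mask: replicate-pad, convolve with an all-ones
+ kernel (despite the `gaussian_kernel` variable name) and clamp the result to [0, 1]."""
+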
+ def __init__(self, kernel_size=7, channels=1, device="cpu"):
+ self.kernel_size = kernel_size
+ self.channels = channels
+ gaussian_kernel = torch.ones(1, 1, self.kernel_size, self.kernel_size)
+ gaussian_kernel = gaussian_kernel.repeat(self.channels, 1, 1, 1)
+ self.mean = (self.kernel_size - 1) // 2
+ gaussian_kernel = gaussian_kernel.to(device)
+ self.gaussian_filter = gaussian_kernel
+
+ def __call__(self, x):
+ x = F.pad(x, (self.mean, self.mean, self.mean, self.mean), "replicate")
+ return torch.clamp(F.conv2d(x, self.gaussian_filter, bias=None), 0, 1)
+
+
+def flow_warp(feature, flow, mask=False, mode="bilinear", padding_mode="zeros"):
+ b, c, h, w = feature.size()
+ assert flow.size(1) == 2
+
+ grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
+ grid = grid.to(feature.dtype)
+ return bilinear_sample(feature, grid, mode=mode, padding_mode=padding_mode, return_mask=mask)
+
+
+def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
+ # fwd_flow, bwd_flow: [B, 2, H, W]
+ # alpha and beta values are following UnFlow
+ # (https://arxiv.org/abs/1711.07837)
+ assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
+ assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
+ flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
+
+ warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
+ warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
+
+ diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
+ diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
+
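+ # a pixel is flagged as occluded when the forward flow and the warped backward flow fail to cancel out,
+ # i.e. their residual exceeds this magnitude-dependent threshold (the UnFlow criterion)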
+ threshold = alpha * flow_mag + beta
+
+ fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
+ bwd_occ = (diff_bwd > threshold).float()
+
+ return fwd_occ, bwd_occ
+
+
+def numpy2tensor(img):
+ x0 = torch.from_numpy(img.copy()).float().cuda() / 255.0 * 2.0 - 1.0
+ x0 = torch.stack([x0], dim=0)
+ # einops.rearrange(x0, 'b h w c -> b c h w').clone()
+ return x0.permute(0, 3, 1, 2)
+
+
+def calc_mean_std(feat, eps=1e-5, chunk=1):
+ size = feat.size()
+ assert len(size) == 4
+ if chunk == 2:
+ feat = torch.cat(feat.chunk(2), dim=3)
+ N, C = size[:2]
+ feat_var = feat.view(N // chunk, C, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(N, C, 1, 1)
+ feat_mean = feat.view(N // chunk, C, -1).mean(dim=2).view(N // chunk, C, 1, 1)
+ return feat_mean.repeat(chunk, 1, 1, 1), feat_std.repeat(chunk, 1, 1, 1)
+
+
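+# adaptive instance normalization (AdaIN): re-normalize `content_feat` so that its per-channel mean/std
+# match those of `style_feat`; used below to keep the optimized latents within the feature statistics of
+# the original sample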
+def adaptive_instance_normalization(content_feat, style_feat, chunk=1):
+ assert content_feat.size()[:2] == style_feat.size()[:2]
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat, chunk)
+ content_mean, content_std = calc_mean_std(content_feat)
+
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def optimize_feature(
+ sample, flows, occs, correlation_matrix=[], intra_weight=1e2, iters=20, unet_chunk_size=2, optimize_temporal=True
+):
+ """
+ FRESCO-guided latent feature optimization
+ * optimize spatial correspondence (match correlation_matrix)
+ * optimize temporal correspondence (match warped_image)
+ """
+ if (flows is None or occs is None or (not optimize_temporal)) and (
+ intra_weight == 0 or len(correlation_matrix) == 0
+ ):
+ return sample
+ # flows=[fwd_flows, bwd_flows]: (N-1)*2*H1*W1
+ # occs=[fwd_occs, bwd_occs]: (N-1)*H1*W1
+ # sample: 2N*C*H*W
+ torch.cuda.empty_cache()
+ video_length = sample.shape[0] // unet_chunk_size
+ latent = rearrange(sample.to(torch.float32), "(b f) c h w -> b f c h w", f=video_length)
+
+ cs = torch.nn.Parameter((latent.detach().clone()))
+ optimizer = torch.optim.Adam([cs], lr=0.2)
+
+ # unify resolution
+ if flows is not None and occs is not None:
+ scale = sample.shape[2] * 1.0 / flows[0].shape[2]
+ kernel = int(1 / scale)
+ bwd_flow_ = F.interpolate(flows[1] * scale, scale_factor=scale, mode="bilinear").repeat(
+ unet_chunk_size, 1, 1, 1
+ )
+ bwd_occ_ = F.max_pool2d(occs[1].unsqueeze(1), kernel_size=kernel).repeat(
+ unet_chunk_size, 1, 1, 1
+ ) # 2(N-1)*1*H1*W1
+ fwd_flow_ = F.interpolate(flows[0] * scale, scale_factor=scale, mode="bilinear").repeat(
+ unet_chunk_size, 1, 1, 1
+ )
+ fwd_occ_ = F.max_pool2d(occs[0].unsqueeze(1), kernel_size=kernel).repeat(
+ unet_chunk_size, 1, 1, 1
+ ) # 2(N-1)*1*H1*W1
+ # match frame 0,1,2,3 and frame 1,2,3,0
+ reshuffle_list = list(range(1, video_length)) + [0]
+
+ # attention_probs is the GRAM matrix of the normalized feature
+ attention_probs = None
+ for tmp in correlation_matrix:
+ if sample.shape[2] * sample.shape[3] == tmp.shape[1]:
+ attention_probs = tmp # 2N*HW*HW
+ break
+
+ n_iter = [0]
+ while n_iter[0] < iters:
+
+ def closure():
+ optimizer.zero_grad()
+
+ loss = 0
+
+ # temporal consistency loss
+ if optimize_temporal and flows is not None and occs is not None:
+ c1 = rearrange(cs[:, :], "b f c h w -> (b f) c h w")
+ c2 = rearrange(cs[:, reshuffle_list], "b f c h w -> (b f) c h w")
+ warped_image1 = flow_warp(c1, bwd_flow_)
+ warped_image2 = flow_warp(c2, fwd_flow_)
+ loss = (
+ abs((c2 - warped_image1) * (1 - bwd_occ_)) + abs((c1 - warped_image2) * (1 - fwd_occ_))
+ ).mean() * 2
+
+ # spatial consistency loss
+ if attention_probs is not None and intra_weight > 0:
+ cs_vector = rearrange(cs, "b f c h w -> (b f) (h w) c")
+ # attention_scores = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))
+ # cs_attention_probs = attention_scores.softmax(dim=-1)
+ cs_vector = cs_vector / ((cs_vector**2).sum(dim=2, keepdims=True) ** 0.5)
+ cs_attention_probs = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))
+ tmp = F.l1_loss(cs_attention_probs, attention_probs) * intra_weight
+ loss = tmp + loss
+
+ loss.backward()
+ n_iter[0] += 1
+
+ return loss
+
+ optimizer.step(closure)
+
+ torch.cuda.empty_cache()
+ return adaptive_instance_normalization(rearrange(cs.data.to(sample.dtype), "b f c h w -> (b f) c h w"), sample)
+
+
+@torch.no_grad()
+def warp_tensor(sample, flows, occs, saliency, unet_chunk_size):
+ """
+ Warp images or features based on optical flow
+ Fuse the warped images or features based on occlusion masks and the saliency map
+ """
+ scale = sample.shape[2] * 1.0 / flows[0].shape[2]
+ kernel = int(1 / scale)
+ bwd_flow_ = F.interpolate(flows[1] * scale, scale_factor=scale, mode="bilinear")
+ bwd_occ_ = F.max_pool2d(occs[1].unsqueeze(1), kernel_size=kernel) # (N-1)*1*H1*W1
+ if scale == 1:
+ bwd_occ_ = Dilate(kernel_size=13, device=sample.device)(bwd_occ_)
+ fwd_flow_ = F.interpolate(flows[0] * scale, scale_factor=scale, mode="bilinear")
+ fwd_occ_ = F.max_pool2d(occs[0].unsqueeze(1), kernel_size=kernel) # (N-1)*1*H1*W1
+ if scale == 1:
+ fwd_occ_ = Dilate(kernel_size=13, device=sample.device)(fwd_occ_)
+ scale2 = sample.shape[2] * 1.0 / saliency.shape[2]
+ saliency = F.interpolate(saliency, scale_factor=scale2, mode="bilinear")
+ latent = sample.to(torch.float32)
+ video_length = sample.shape[0] // unet_chunk_size
+ warp_saliency = flow_warp(saliency, bwd_flow_)
+ warp_saliency_ = flow_warp(saliency[0:1], fwd_flow_[video_length - 1 : video_length])
+
+ for j in range(unet_chunk_size):
+ for ii in range(video_length - 1):
+ i = video_length * j + ii
+ warped_image = flow_warp(latent[i : i + 1], bwd_flow_[ii : ii + 1])
+ mask = (1 - bwd_occ_[ii : ii + 1]) * saliency[ii + 1 : ii + 2] * warp_saliency[ii : ii + 1]
+ latent[i + 1 : i + 2] = latent[i + 1 : i + 2] * (1 - mask) + warped_image * mask
+ i = video_length * j
+ ii = video_length - 1
+ warped_image = flow_warp(latent[i : i + 1], fwd_flow_[ii : ii + 1])
+ mask = (1 - fwd_occ_[ii : ii + 1]) * saliency[ii : ii + 1] * warp_saliency_
+ latent[ii + i : ii + i + 1] = latent[ii + i : ii + i + 1] * (1 - mask) + warped_image * mask
+
+ return latent.to(sample.dtype)
+
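+ # Note on warp_tensor above: frame i+1 is blended with frame i warped by the
+ # backward flow, latent[i+1] = latent[i+1] * (1 - mask) + warp(latent[i]) * mask,
+ # where mask combines the non-occluded region with the saliency maps of the two
+ # frames; the last frame is additionally blended with frame 0 warped by the
+ # forward flow, closing the loop over the batch.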
+
+def my_forward(
+ self,
+ steps=[],
+ layers=[0, 1, 2, 3],
+ flows=None,
+ occs=None,
+ correlation_matrix=[],
+ intra_weight=1e2,
+ iters=20,
+ optimize_temporal=True,
+ saliency=None,
+):
+ """
+ Hacked pipe.unet.forward()
+ copied from https://github.com/huggingface/diffusers/blob/v0.19.3/src/diffusers/models/unet_2d_condition.py#L700
+ If you are using a newer version of diffusers, please copy the source code and modify it accordingly (search for [HACK] in the code)
+ * restore and return the decoder features
+ * optimize the decoder features
+ * perform background smoothing
+ """
+
+ def forward(
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ """
+ [HACK] restore the decoder features in up_samples
+ """
+ up_samples = ()
+ # down_samples = ()
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ """
+ [HACK] restore the decoder features in up_samples
+ [HACK] optimize the decoder features
+ [HACK] perform background smoothing
+ """
+ if i in layers:
+ up_samples += (sample,)
+ if timestep in steps and i in layers:
+ sample = optimize_feature(
+ sample, flows, occs, correlation_matrix, intra_weight, iters, optimize_temporal=optimize_temporal
+ )
+ if saliency is not None:
+ sample = warp_tensor(sample, flows, occs, saliency, 2)
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ """
+ [HACK] return the output feature as well as the decoder features
+ """
+ if not return_dict:
+ return (sample,) + up_samples
+
+ return UNet2DConditionOutput(sample=sample)
+
+ return forward
+
+
+@torch.no_grad()
+def get_single_mapping_ind(bwd_flow, bwd_occ, imgs, scale=1.0):
+ """
+ FLATTEN: Optical fLow-guided attention (Temporal-guided attention)
+ Find the correspondence between all pixels in a pair of frames
+
+ [input]
+ bwd_flow: 1*2*H*W
+ bwd_occ: 1*H*W i.e., f2 ≈ warp(f1, bwd_flow) where bwd_occ == 0
+ imgs: 2*3*H*W i.e., [f1,f2]
+
+ [output]
+ mapping_ind: pixel index correspondence
+ unlinkedmask: indicates whether a pixel has no correspondence
+ i.e., f2 = f1[mapping_ind] * unlinkedmask
+ """
+ flows = F.interpolate(bwd_flow, scale_factor=1.0 / scale, mode="bilinear")[0][[1, 0]] / scale # 2*H*W
+ _, H, W = flows.shape
+ masks = torch.logical_not(F.interpolate(bwd_occ[None], scale_factor=1.0 / scale, mode="bilinear") > 0.5)[
+ 0
+ ] # 1*H*W
+ frames = F.interpolate(imgs, scale_factor=1.0 / scale, mode="bilinear").view(2, 3, -1) # 2*3*HW
+ grid = torch.stack(torch.meshgrid([torch.arange(H), torch.arange(W)]), dim=0).to(flows.device) # 2*H*W
+ warp_grid = torch.round(grid + flows)
+ mask = torch.logical_and(
+ torch.logical_and(
+ torch.logical_and(torch.logical_and(warp_grid[0] >= 0, warp_grid[0] < H), warp_grid[1] >= 0),
+ warp_grid[1] < W,
+ ),
+ masks[0],
+ ).view(-1) # HW
+ warp_grid = warp_grid.view(2, -1) # 2*HW
+ warp_ind = (warp_grid[0] * W + warp_grid[1]).to(torch.long) # HW
+ mapping_ind = torch.zeros_like(warp_ind) - 1 # HW
+
+ for f0ind, f1ind in enumerate(warp_ind):
+ if mask[f0ind]:
+ if mapping_ind[f1ind] == -1:
+ mapping_ind[f1ind] = f0ind
+ else:
+ targetv = frames[0, :, f1ind]
+ pref0ind = mapping_ind[f1ind]
+ prev = frames[1, :, pref0ind]
+ v = frames[1, :, f0ind]
+ if ((prev - targetv) ** 2).mean() > ((v - targetv) ** 2).mean():
+ mask[pref0ind] = False
+ mapping_ind[f1ind] = f0ind
+ else:
+ mask[f0ind] = False
+
+ unusedind = torch.arange(len(mask)).to(mask.device)[~mask]
+ unlinkedmask = mapping_ind == -1
+ mapping_ind[unlinkedmask] = unusedind
+ return mapping_ind, unlinkedmask
+
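+ # Note on the loop in get_single_mapping_ind above: when two source pixels land on
+ # the same target pixel, the conflict is resolved greedily by keeping the source
+ # whose colour is closer to the target and marking the loser as unlinked; the
+ # remaining unlinked targets are then padded with the unused source indices, so
+ # mapping_ind stays a one-to-one mapping over pixel indices (needed for the
+ # gather/sort in get_mapping_ind below).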
+
+@torch.no_grad()
+def get_mapping_ind(bwd_flows, bwd_occs, imgs, scale=1.0):
+ """
+ FLATTEN: Optical fLow-guided attention (Temporal-guided attention)
+ Find pixel correspondences between consecutive frames in a batch
+
+ [input]
+ bwd_flow: (N-1)*2*H*W
+ bwd_occ: (N-1)*H*W
+ imgs: N*3*H*W
+
+ [output]
+ fwd_mappings: N*1*HW
+ bwd_mappings: N*1*HW
+ interattn_mask: HW*1*N*N
+ i.e., imgs[i,:,fwd_mappings[i]] corresponds to imgs[0]
+ i.e., imgs[i,:,fwd_mappings[i]][:,bwd_mappings[i]] restores the original imgs[i]
+ """
+ N, H, W = imgs.shape[0], int(imgs.shape[2] // scale), int(imgs.shape[3] // scale)
+ iterattn_mask = torch.ones(H * W, N, N, dtype=torch.bool).to(imgs.device)
+ for i in range(len(imgs) - 1):
+ one_mask = torch.ones(N, N, dtype=torch.bool).to(imgs.device)
+ one_mask[: i + 1, i + 1 :] = False
+ one_mask[i + 1 :, : i + 1] = False
+ mapping_ind, unlinkedmask = get_single_mapping_ind(
+ bwd_flows[i : i + 1], bwd_occs[i : i + 1], imgs[i : i + 2], scale
+ )
+ if i == 0:
+ fwd_mapping = [torch.arange(len(mapping_ind)).to(mapping_ind.device)]
+ bwd_mapping = [torch.arange(len(mapping_ind)).to(mapping_ind.device)]
+ iterattn_mask[unlinkedmask[fwd_mapping[-1]]] = torch.logical_and(
+ iterattn_mask[unlinkedmask[fwd_mapping[-1]]], one_mask
+ )
+ fwd_mapping += [mapping_ind[fwd_mapping[-1]]]
+ bwd_mapping += [torch.sort(fwd_mapping[-1])[1]]
+ fwd_mappings = torch.stack(fwd_mapping, dim=0).unsqueeze(1)
+ bwd_mappings = torch.stack(bwd_mapping, dim=0).unsqueeze(1)
+ return fwd_mappings, bwd_mappings, iterattn_mask.unsqueeze(1)
+
+
+def apply_FRESCO_opt(
+ pipe,
+ steps=[],
+ layers=[0, 1, 2, 3],
+ flows=None,
+ occs=None,
+ correlation_matrix=[],
+ intra_weight=1e2,
+ iters=20,
+ optimize_temporal=True,
+ saliency=None,
+):
+ """
+ Apply FRESCO-based optimization to a StableDiffusionPipeline
+ """
+ pipe.unet.forward = my_forward(
+ pipe.unet, steps, layers, flows, occs, correlation_matrix, intra_weight, iters, optimize_temporal, saliency
+ )
+
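+ # Usage sketch for apply_FRESCO_opt above (hypothetical values): enable the
+ # optimization for an early subset of timesteps on all four decoder layers;
+ # `keyframe_flows`, `keyframe_occs` and `correlation_matrix` would come from the
+ # two helpers defined below.
+ #   opt_steps = pipe.scheduler.timesteps[:int(len(pipe.scheduler.timesteps) * 0.75)]
+ #   apply_FRESCO_opt(pipe, steps=opt_steps, layers=[0, 1, 2, 3],
+ #                    flows=keyframe_flows, occs=keyframe_occs,
+ #                    correlation_matrix=correlation_matrix)
+ # Calling apply_FRESCO_opt(pipe) with the defaults restores a forward pass with
+ # no optimization, since the empty `steps` list never matches any timestep.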
+
+@torch.no_grad()
+def get_intraframe_paras(pipe, imgs, frescoProc, prompt_embeds, do_classifier_free_guidance=True, generator=None):
+ """
+ Get parameters for spatial-guided attention and optimization
+ * perform one denoising step
+ * collect attention features, stored in frescoProc.controller.stored_attn['decoder_attn']
+ * compute the Gram matrix of the normalized features for the spatial consistency loss
+ """
+
+ noise_scheduler = pipe.scheduler
+ timestep = noise_scheduler.timesteps[-1]
+ device = pipe._execution_device
+ B, C, H, W = imgs.shape
+
+ frescoProc.controller.disable_controller()
+ apply_FRESCO_opt(pipe)
+ frescoProc.controller.clear_store()
+ frescoProc.controller.enable_store()
+
+ latents = pipe.prepare_latents(
+ imgs.to(pipe.unet.dtype), timestep, B, 1, prompt_embeds.dtype, device, generator=generator, repeat_noise=False
+ )
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ model_output = pipe.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=None,
+ return_dict=False,
+ )
+
+ frescoProc.controller.disable_store()
+
+ # gram matrix of the normalized feature for spatial consistency loss
+ correlation_matrix = []
+ for tmp in model_output[1:]:
+ latent_vector = rearrange(tmp, "b c h w -> b (h w) c")
+ latent_vector = latent_vector / ((latent_vector**2).sum(dim=2, keepdims=True) ** 0.5)
+ attention_probs = torch.bmm(latent_vector, latent_vector.transpose(-1, -2))
+ correlation_matrix += [attention_probs.detach().clone().to(torch.float32)]
+ del attention_probs, latent_vector, tmp
+ del model_output
+
+ clear_cache()
+
+ return correlation_matrix
+
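+ # Usage sketch for get_intraframe_paras above: run once per batch of keyframes
+ # before the sampling loop (`imgs` is assumed to be a B*3*H*W float tensor already
+ # normalized for the VAE):
+ #   correlation_matrix = get_intraframe_paras(pipe, imgs, frescoProc, prompt_embeds,
+ #                                             generator=generator)
+ #   frescoProc.controller.enable_intraattn()  # reuse the stored decoder features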
+
+@torch.no_grad()
+def get_flow_and_interframe_paras(flow_model, imgs):
+ """
+ Get parameters for temporal-guided attention and optimization
+ * predict optical flow and occlusion mask
+ * compute pixel index correspondence for FLATTEN
+ """
+ images = torch.stack([torch.from_numpy(img).permute(2, 0, 1).float() for img in imgs], dim=0).cuda()
+ imgs_torch = torch.cat([numpy2tensor(img) for img in imgs], dim=0)
+
+ reshuffle_list = list(range(1, len(images))) + [0]
+
+ results_dict = flow_model(
+ images,
+ images[reshuffle_list],
+ attn_splits_list=[2],
+ corr_radius_list=[-1],
+ prop_radius_list=[-1],
+ pred_bidir_flow=True,
+ )
+ flow_pr = results_dict["flow_preds"][-1] # [2*B, 2, H, W]
+ fwd_flows, bwd_flows = flow_pr.chunk(2) # [B, 2, H, W]
+ fwd_occs, bwd_occs = forward_backward_consistency_check(fwd_flows, bwd_flows) # [B, H, W]
+
+ warped_image1 = flow_warp(images, bwd_flows)
+ bwd_occs = torch.clamp(
+ bwd_occs + (abs(images[reshuffle_list] - warped_image1).mean(dim=1) > 255 * 0.25).float(), 0, 1
+ )
+
+ warped_image2 = flow_warp(images[reshuffle_list], fwd_flows)
+ fwd_occs = torch.clamp(fwd_occs + (abs(images - warped_image2).mean(dim=1) > 255 * 0.25).float(), 0, 1)
+
+ attn_mask = []
+ for scale in [8.0, 16.0, 32.0]:
+ bwd_occs_ = F.interpolate(bwd_occs[:-1].unsqueeze(1), scale_factor=1.0 / scale, mode="bilinear")
+ attn_mask += [
+ torch.cat((bwd_occs_[0:1].reshape(1, -1) > -1, bwd_occs_.reshape(bwd_occs_.shape[0], -1) > 0.5), dim=0)
+ ]
+
+ fwd_mappings = []
+ bwd_mappings = []
+ interattn_masks = []
+ for scale in [8.0, 16.0]:
+ fwd_mapping, bwd_mapping, interattn_mask = get_mapping_ind(bwd_flows, bwd_occs, imgs_torch, scale=scale)
+ fwd_mappings += [fwd_mapping]
+ bwd_mappings += [bwd_mapping]
+ interattn_masks += [interattn_mask]
+
+ interattn_paras = {}
+ interattn_paras["fwd_mappings"] = fwd_mappings
+ interattn_paras["bwd_mappings"] = bwd_mappings
+ interattn_paras["interattn_masks"] = interattn_masks
+
+ clear_cache()
+
+ return [fwd_flows, bwd_flows], [fwd_occs, bwd_occs], attn_mask, interattn_paras
+
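+ # Usage sketch for get_flow_and_interframe_paras above (`frames` is assumed to be
+ # a list of H*W*3 uint8 numpy images and `flow_model` a GMFlow instance, as built
+ # in FrescoV2VPipeline.__init__ below):
+ #   flows, occs, attn_mask, interattn_paras = get_flow_and_interframe_paras(flow_model, frames)
+ #   frescoProc.controller.enable_cfattn(attn_mask)            # cross-frame attention
+ #   frescoProc.controller.enable_interattn(interattn_paras)   # temporal-guided attention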
+
+class AttentionControl:
+ """
+ Control FRESCO-based attention
+ * enable/disable spatial-guided attention
+ * enable/disable temporal-guided attention
+ * enable/disable cross-frame attention
+ * collect intermediate attention features (for spatial-guided attention)
+ """
+
+ def __init__(self):
+ self.stored_attn = self.get_empty_store()
+ self.store = False
+ self.index = 0
+ self.attn_mask = None
+ self.interattn_paras = None
+ self.use_interattn = False
+ self.use_cfattn = False
+ self.use_intraattn = False
+ self.intraattn_bias = 0
+ self.intraattn_scale_factor = 0.2
+ self.interattn_scale_factor = 0.2
+
+ @staticmethod
+ def get_empty_store():
+ return {
+ "decoder_attn": [],
+ }
+
+ def clear_store(self):
+ del self.stored_attn
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.stored_attn = self.get_empty_store()
+ self.disable_intraattn()
+
+ # store attention feature of the input frame for spatial-guided attention
+ def enable_store(self):
+ self.store = True
+
+ def disable_store(self):
+ self.store = False
+
+ # spatial-guided attention
+ def enable_intraattn(self):
+ self.index = 0
+ self.use_intraattn = True
+ self.disable_store()
+ if len(self.stored_attn["decoder_attn"]) == 0:
+ self.use_intraattn = False
+
+ def disable_intraattn(self):
+ self.index = 0
+ self.use_intraattn = False
+ self.disable_store()
+
+ def disable_cfattn(self):
+ self.use_cfattn = False
+
+ # cross frame attention
+ def enable_cfattn(self, attn_mask=None):
+ if attn_mask:
+ if self.attn_mask:
+ del self.attn_mask
+ torch.cuda.empty_cache()
+ self.attn_mask = attn_mask
+ self.use_cfattn = True
+ else:
+ if self.attn_mask:
+ self.use_cfattn = True
+ else:
+ print("Warning: no valid cross-frame attention parameters available!")
+ self.disable_cfattn()
+
+ def disable_interattn(self):
+ self.use_interattn = False
+
+ # temporal-guided attention
+ def enable_interattn(self, interattn_paras=None):
+ if interattn_paras:
+ if self.interattn_paras:
+ del self.interattn_paras
+ torch.cuda.empty_cache()
+ self.interattn_paras = interattn_paras
+ self.use_interattn = True
+ else:
+ if self.interattn_paras:
+ self.use_interattn = True
+ else:
+ print("Warning: no valid temporal-guided attention parameters available!")
+ self.disable_interattn()
+
+ def disable_controller(self):
+ self.disable_intraattn()
+ self.disable_interattn()
+ self.disable_cfattn()
+
+ def enable_controller(self, interattn_paras=None, attn_mask=None):
+ self.enable_intraattn()
+ self.enable_interattn(interattn_paras)
+ self.enable_cfattn(attn_mask)
+
+ def forward(self, context):
+ if self.store:
+ self.stored_attn["decoder_attn"].append(context.detach())
+ if self.use_intraattn and len(self.stored_attn["decoder_attn"]) > 0:
+ tmp = self.stored_attn["decoder_attn"][self.index]
+ self.index = self.index + 1
+ if self.index >= len(self.stored_attn["decoder_attn"]):
+ self.index = 0
+ self.disable_store()
+ return tmp
+ return context
+
+ def __call__(self, context):
+ context = self.forward(context)
+ return context
+
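+ # Typical control flow (sketch): a single AttentionControl instance is shared by
+ # all FRESCO attention processors of a pipeline.
+ #   controller.enable_controller(interattn_paras=interattn_paras, attn_mask=attn_mask)
+ #   # ... run the denoising loop with FRESCO guidance ...
+ #   controller.disable_controller()  # fall back to vanilla self-attention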
+
+class FRESCOAttnProcessor2_0:
+ """
+ Hack self-attention into FRESCO-based attention
+ * add spatial-guided attention
+ * add temporal-guided attention
+ * add cross-frame attention
+
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ Usage:
+ frescoProc = FRESCOAttnProcessor2_0(2, AttentionControl())
+ attnProc = AttnProcessor2_0()
+
+ attn_processor_dict = {}
+ for k in pipe.unet.attn_processors.keys():
+ if k.startswith("up_blocks.2") or k.startswith("up_blocks.3"):
+ attn_processor_dict[k] = frescoProc
+ else:
+ attn_processor_dict[k] = attnProc
+ pipe.unet.set_attn_processor(attn_processor_dict)
+ """
+
+ def __init__(self, unet_chunk_size=2, controller=None):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ self.unet_chunk_size = unet_chunk_size
+ self.controller = controller
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ crossattn = False
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ if self.controller and self.controller.store:
+ self.controller(hidden_states.detach().clone())
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ # BC * HW * 8D
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query_raw, key_raw = None, None
+ if self.controller and self.controller.use_interattn and (not crossattn):
+ query_raw, key_raw = query.clone(), key.clone()
+
+ inner_dim = key.shape[-1] # 8D
+ head_dim = inner_dim // attn.heads # D
+
+ """for efficient cross-frame attention"""
+ if self.controller and self.controller.use_cfattn and (not crossattn):
+ video_length = key.size()[0] // self.unet_chunk_size
+ former_frame_index = [0] * video_length
+ attn_mask = None
+ if self.controller.attn_mask is not None:
+ for m in self.controller.attn_mask:
+ if m.shape[1] == key.shape[1]:
+ attn_mask = m
+ # BC * HW * 8D --> B * C * HW * 8D
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+ # B * C * HW * 8D --> B * C * HW * 8D
+ if attn_mask is None:
+ key = key[:, former_frame_index]
+ else:
+ key = repeat(key[:, attn_mask], "b d c -> b f d c", f=video_length)
+ # B * C * HW * 8D --> BC * HW * 8D
+ key = rearrange(key, "b f d c -> (b f) d c").detach()
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+ if attn_mask is None:
+ value = value[:, former_frame_index]
+ else:
+ value = repeat(value[:, attn_mask], "b d c -> b f d c", f=video_length)
+ value = rearrange(value, "b f d c -> (b f) d c").detach()
+
+ # BC * HW * 8D --> BC * HW * 8 * D --> BC * 8 * HW * D
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ # BC * 8 * HW2 * D
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ # BC * 8 * HW2 * D2
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ """for spatial-guided intra-frame attention"""
+ if self.controller and self.controller.use_intraattn and (not crossattn):
+ ref_hidden_states = self.controller(None)
+ assert ref_hidden_states.shape == encoder_hidden_states.shape
+ query_ = attn.to_q(ref_hidden_states)
+ key_ = attn.to_k(ref_hidden_states)
+
+ # BC * 8 * HW * D
+ query_ = query_.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_ = key_.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ query = F.scaled_dot_product_attention(
+ query_,
+ key_ * self.controller.intraattn_scale_factor,
+ query,
+ attn_mask=torch.eye(query_.size(-2), key_.size(-2), dtype=query.dtype, device=query.device)
+ * self.controller.intraattn_bias,
+ ).detach()
+
+ del query_, key_
+ torch.cuda.empty_cache()
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ # output: BC * 8 * HW * D2
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ """for temporal-guided inter-frame attention (FLATTEN)"""
+ if self.controller and self.controller.use_interattn and (not crossattn):
+ del query, key, value
+ torch.cuda.empty_cache()
+ bwd_mapping = None
+ fwd_mapping = None
+ for i, f in enumerate(self.controller.interattn_paras["fwd_mappings"]):
+ if f.shape[2] == hidden_states.shape[2]:
+ fwd_mapping = f
+ bwd_mapping = self.controller.interattn_paras["bwd_mappings"][i]
+ interattn_mask = self.controller.interattn_paras["interattn_masks"][i]
+ video_length = key_raw.size()[0] // self.unet_chunk_size
+ # BC * HW * 8D --> C * 8BD * HW
+ key = rearrange(key_raw, "(b f) d c -> f (b c) d", f=video_length)
+ query = rearrange(query_raw, "(b f) d c -> f (b c) d", f=video_length)
+ # BC * 8 * HW * D --> C * 8BD * HW
+ # key = rearrange(hidden_states, "(b f) h d c -> f (b h c) d", f=video_length) ########
+ # query = rearrange(hidden_states, "(b f) h d c -> f (b h c) d", f=video_length) #######
+
+ value = rearrange(hidden_states, "(b f) h d c -> f (b h c) d", f=video_length)
+ key = torch.gather(key, 2, fwd_mapping.expand(-1, key.shape[1], -1))
+ query = torch.gather(query, 2, fwd_mapping.expand(-1, query.shape[1], -1))
+ value = torch.gather(value, 2, fwd_mapping.expand(-1, value.shape[1], -1))
+ # C * 8BD * HW --> BHW, C, 8D
+ key = rearrange(key, "f (b c) d -> (b d) f c", b=self.unet_chunk_size)
+ query = rearrange(query, "f (b c) d -> (b d) f c", b=self.unet_chunk_size)
+ value = rearrange(value, "f (b c) d -> (b d) f c", b=self.unet_chunk_size)
+ # BHW * C * 8D --> BHW * C * 8 * D--> BHW * 8 * C * D
+ query = query.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()
+ key = key.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()
+ value = value.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()
+ hidden_states_ = F.scaled_dot_product_attention(
+ query,
+ key * self.controller.interattn_scale_factor,
+ value,
+ # .to(query.dtype)-1.0) * 1e6 -
+ attn_mask=(interattn_mask.repeat(self.unet_chunk_size, 1, 1, 1)),
+ # torch.eye(interattn_mask.shape[2]).to(query.device).to(query.dtype) * 1e4,
+ )
+
+ # BHW * 8 * C * D --> C * 8BD * HW
+ hidden_states_ = rearrange(hidden_states_, "(b d) h f c -> f (b h c) d", b=self.unet_chunk_size)
+ hidden_states_ = torch.gather(
+ hidden_states_, 2, bwd_mapping.expand(-1, hidden_states_.shape[1], -1)
+ ).detach()
+ # C * 8BD * HW --> BC * 8 * HW * D
+ hidden_states = rearrange(
+ hidden_states_, "f (b h c) d -> (b f) h d c", b=self.unet_chunk_size, h=attn.heads
+ )
+
+ # BC * 8 * HW * D --> BC * HW * 8D
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+def apply_FRESCO_attn(pipe):
+ """
+ Apply FRESCO-guided attention to a StableDiffusionPipeline
+ """
+ frescoProc = FRESCOAttnProcessor2_0(2, AttentionControl())
+ attnProc = AttnProcessor2_0()
+ attn_processor_dict = {}
+ for k in pipe.unet.attn_processors.keys():
+ if k.startswith("up_blocks.2") or k.startswith("up_blocks.3"):
+ attn_processor_dict[k] = frescoProc
+ else:
+ attn_processor_dict[k] = attnProc
+ pipe.unet.set_attn_processor(attn_processor_dict)
+ return frescoProc
+
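+ # Putting the two hooks together (sketch; `frames`, `imgs`, `prompt_embeds` and
+ # `opt_steps` are placeholders, the rest are outputs of the helpers above):
+ #   frescoProc = apply_FRESCO_attn(pipe)
+ #   flows, occs, attn_mask, interattn_paras = get_flow_and_interframe_paras(flow_model, frames)
+ #   correlation_matrix = get_intraframe_paras(pipe, imgs, frescoProc, prompt_embeds)
+ #   frescoProc.controller.enable_controller(interattn_paras, attn_mask)
+ #   apply_FRESCO_opt(pipe, steps=opt_steps, flows=flows, occs=occs,
+ #                    correlation_matrix=correlation_matrix)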
+
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
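+ # prepare_image above converts PIL images / numpy arrays into a B*3*H*W float
+ # tensor in [-1, 1] (torch tensors are only batched and cast to float32), e.g.
+ # with a hypothetical file name:
+ #   cond = prepare_image(Image.open("frame_000.png").convert("RGB"))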
+
+class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
+ r"""
+ Pipeline for video-to-video translation using Stable Diffusion with FRESCO Algorithm.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
+ additional conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ controlnet,
+ scheduler,
+ safety_checker,
+ feature_extractor,
+ image_encoder,
+ requires_safety_checker,
+ )
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ frescoProc = FRESCOAttnProcessor2_0(2, AttentionControl())
+ attnProc = AttnProcessor2_0()
+ attn_processor_dict = {}
+ for k in self.unet.attn_processors.keys():
+ if k.startswith("up_blocks.2") or k.startswith("up_blocks.3"):
+ attn_processor_dict[k] = frescoProc
+ else:
+ attn_processor_dict[k] = attnProc
+ self.unet.set_attn_processor(attn_processor_dict)
+ self.frescoProc = frescoProc
+
+ flow_model = GMFlow(
+ feature_channels=128,
+ num_scales=1,
+ upsample_factor=8,
+ num_head=1,
+ attention_type="swin",
+ ffn_dim_expansion=4,
+ num_transformer_layers=6,
+ ).to(self.device)
+
+ checkpoint = torch.utils.model_zoo.load_url(
+ "https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth",
+ map_location=lambda storage, loc: storage,
+ )
+ weights = checkpoint["model"] if "model" in checkpoint else checkpoint
+ flow_model.load_state_dict(weights, strict=False)
+ flow_model.eval()
+ self.flow_model = flow_model
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
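+        # Note: `and` binds tighter than `or`, so the condition below reads as "a plain
+        # ControlNetModel, OR a compiled module whose original module is a ControlNetModel".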
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in image):
+                raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+                    raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+                        f"control guidance start: {start} cannot be larger than or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
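+        # e.g. with num_inference_steps=50 and strength=0.8: init_timestep=40 and t_start=10,
+        # so denoising runs over the last 40 of the 50 scheduler timesteps.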
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
+ def prepare_latents(
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, repeat_noise, generator=None
+ ):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
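+        # When `repeat_noise` is True, a single noise sample is drawn and repeated across the
+        # batch, so every latent in the batch is perturbed with identical noise (presumably to
+        # keep the per-frame starting points consistent across frames).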
+ if repeat_noise:
+ noise = randn_tensor((1, *shape[1:]), generator=generator, device=device, dtype=dtype)
+ one_tuple = (1,) * (len(shape) - 1)
+ noise = noise.repeat(batch_size, *one_tuple)
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ frames: Union[List[np.ndarray], torch.FloatTensor] = None,
+ control_frames: Union[List[np.ndarray], torch.FloatTensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        end_opt_step: int = 15,
+        num_intraattn_steps: int = 1,
+        step_interattn_end: int = 350,
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            frames (`List[np.ndarray]` or `torch.FloatTensor`):
+                The input images to be used as the starting point for the image generation process.
+            control_frames (`List[np.ndarray]` or `torch.FloatTensor`):
+                The ControlNet input condition images used to provide guidance to the `unet` for generation.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 0.8):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            end_opt_step (`int`, *optional*, defaults to 15):
+                The feature optimization is activated from step `strength * num_inference_steps` to step
+                `end_opt_step`.
+            num_intraattn_steps (`int`, *optional*, defaults to 1):
+                Apply spatial-guided attention for the first `num_intraattn_steps` denoising steps.
+            step_interattn_end (`int`, *optional*, defaults to 350):
+                Apply temporal-guided attention for timesteps in `[step_interattn_end, 1000]`.
+
+ Examples:
+
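+        A minimal illustrative sketch (not a complete script): `pipe` is assumed to be an
+        already-constructed instance of this pipeline, and `frames` / `control_frames` are
+        equally long lists of `np.ndarray` video frames and their control images:
+
+        ```py
+        output = pipe(
+            prompt="a beautiful woman in CG style",
+            frames=frames,
+            control_frames=control_frames,
+            width=512,
+            height=512,
+            strength=0.75,
+            num_inference_steps=50,
+            guidance_scale=7.5,
+        )
+        result_frames = output.images
+        ```
+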
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ control_frames[0],
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ batch_size = len(frames)
+
+ device = self._execution_device
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. Prepare image
+ imgs_np = []
+ for frame in frames:
+ if isinstance(frame, PIL.Image.Image):
+ imgs_np.append(np.asarray(frame))
+ else:
+ # np.ndarray
+ imgs_np.append(frame)
+ images_pt = self.image_processor.preprocess(frames).to(dtype=torch.float32)
+
+ # 5. Prepare controlnet_conditioning_image
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_frames,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_frames:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ assert False
+
+ self.flow_model.to(device)
+
+ flows, occs, attn_mask, interattn_paras = get_flow_and_interframe_paras(self.flow_model, imgs_np)
+ correlation_matrix = get_intraframe_paras(self, images_pt, self.frescoProc, prompt_embeds, generator)
+
+ """
+ Flexible settings for attention:
+ * Turn off FRESCO-guided attention: frescoProc.controller.disable_controller()
+ Then you can turn on one specific attention submodule
+ * Turn on Cross-frame attention: frescoProc.controller.enable_cfattn(attn_mask)
+ * Turn on Spatial-guided attention: frescoProc.controller.enable_intraattn()
+ * Turn on Temporal-guided attention: frescoProc.controller.enable_interattn(interattn_paras)
+
+ Flexible settings for optimization:
+ * Turn off Spatial-guided optimization: set optimize_temporal = False in apply_FRESCO_opt()
+ * Turn off Temporal-guided optimization: set correlation_matrix = [] in apply_FRESCO_opt()
+ * Turn off FRESCO-guided optimization: disable_FRESCO_opt(pipe)
+
+ Flexible settings for background smoothing:
+ * Turn off background smoothing: set saliency = None in apply_FRESCO_opt()
+ """
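+        # Illustrative example (not executed here): to keep only cross-frame attention, one could
+        # replace the `enable_controller` call below with
+        #     self.frescoProc.controller.disable_controller()
+        #     self.frescoProc.controller.enable_cfattn(attn_mask)
+        # following the options listed above.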
+
+ self.frescoProc.controller.enable_controller(interattn_paras=interattn_paras, attn_mask=attn_mask)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ apply_FRESCO_opt(
+ self,
+ steps=timesteps[:end_opt_step],
+ flows=flows,
+ occs=occs,
+ correlation_matrix=correlation_matrix,
+ saliency=None,
+ optimize_temporal=True,
+ )
+
+ clear_cache()
+
+        # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ self._num_timesteps = len(timesteps)
+
+        # 7. Prepare latent variables
+ latents = self.prepare_latents(
+ images_pt,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator=generator,
+ repeat_noise=True,
+ )
+
+        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+        # 8.2 Create tensor stating which controlnets to keep
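+        # keeps[j] is 1.0 while the current step fraction i / len(timesteps) lies inside the
+        # window [control_guidance_start[j], control_guidance_end[j]] and 0.0 otherwise, so each
+        # ControlNet only contributes within its configured guidance window.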
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+        # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if i >= num_intraattn_steps:
+ self.frescoProc.controller.disable_intraattn()
+ if t < step_interattn_end:
+ self.frescoProc.controller.disable_interattn()
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/gluegen.py b/diffusers/examples/community/gluegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..91026c5d966f824302a8e454e3016007171c4c43
--- /dev/null
+++ b/diffusers/examples/community/gluegen.py
@@ -0,0 +1,816 @@
+import inspect
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor
+
+from diffusers import DiffusionPipeline
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class TranslatorBase(nn.Module):
+ def __init__(self, num_tok, dim, dim_out, mult=2):
+ super().__init__()
+
+ self.dim_in = dim
+ self.dim_out = dim_out
+
+ self.net_tok = nn.Sequential(
+ nn.Linear(num_tok, int(num_tok * mult)),
+ nn.LayerNorm(int(num_tok * mult)),
+ nn.GELU(),
+ nn.Linear(int(num_tok * mult), int(num_tok * mult)),
+ nn.LayerNorm(int(num_tok * mult)),
+ nn.GELU(),
+ nn.Linear(int(num_tok * mult), num_tok),
+ nn.LayerNorm(num_tok),
+ )
+
+ self.net_sen = nn.Sequential(
+ nn.Linear(dim, int(dim * mult)),
+ nn.LayerNorm(int(dim * mult)),
+ nn.GELU(),
+ nn.Linear(int(dim * mult), int(dim * mult)),
+ nn.LayerNorm(int(dim * mult)),
+ nn.GELU(),
+ nn.Linear(int(dim * mult), dim_out),
+ nn.LayerNorm(dim_out),
+ )
+
+ def forward(self, x):
+ if self.dim_in == self.dim_out:
+            identity_0 = x
+            x = self.net_sen(x)
+            x += identity_0
+            x = x.transpose(1, 2)
+
+            identity_1 = x
+            x = self.net_tok(x)
+            x += identity_1
+ x = x.transpose(1, 2)
+ else:
+ x = self.net_sen(x)
+ x = x.transpose(1, 2)
+
+ x = self.net_tok(x)
+ x = x.transpose(1, 2)
+ return x
+
+
+class TranslatorBaseNoLN(nn.Module):
+ def __init__(self, num_tok, dim, dim_out, mult=2):
+ super().__init__()
+
+ self.dim_in = dim
+ self.dim_out = dim_out
+
+ self.net_tok = nn.Sequential(
+ nn.Linear(num_tok, int(num_tok * mult)),
+ nn.GELU(),
+ nn.Linear(int(num_tok * mult), int(num_tok * mult)),
+ nn.GELU(),
+ nn.Linear(int(num_tok * mult), num_tok),
+ )
+
+ self.net_sen = nn.Sequential(
+ nn.Linear(dim, int(dim * mult)),
+ nn.GELU(),
+ nn.Linear(int(dim * mult), int(dim * mult)),
+ nn.GELU(),
+ nn.Linear(int(dim * mult), dim_out),
+ )
+
+ def forward(self, x):
+ if self.dim_in == self.dim_out:
+            identity_0 = x
+            x = self.net_sen(x)
+            x += identity_0
+            x = x.transpose(1, 2)
+
+            identity_1 = x
+            x = self.net_tok(x)
+            x += identity_1
+ x = x.transpose(1, 2)
+ else:
+ x = self.net_sen(x)
+ x = x.transpose(1, 2)
+
+ x = self.net_tok(x)
+ x = x.transpose(1, 2)
+ return x
+
+
+class TranslatorNoLN(nn.Module):
+ def __init__(self, num_tok, dim, dim_out, mult=2, depth=5):
+ super().__init__()
+
+ self.blocks = nn.ModuleList([TranslatorBase(num_tok, dim, dim, mult=2) for d in range(depth)])
+ self.gelu = nn.GELU()
+
+ self.tail = TranslatorBaseNoLN(num_tok, dim, dim_out, mult=2)
+
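+    # Shape note (illustrative): with num_tok=77, dim=1024, dim_out=768, this module maps token
+    # embeddings of shape (batch, 77, 1024), e.g. from a multilingual text encoder such as
+    # XLM-RoBERTa-large, to the (batch, 77, 768) space expected by the Stable Diffusion v1.x UNet.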
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x) + x
+ x = self.gelu(x)
+
+ x = self.tail(x)
+ return x
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
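+
+    Illustrative usage (as typically applied right after combining the classifier-free guidance
+    branches; `guidance_rescale` is a float in `[0, 1]`):
+
+    ```py
+    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)
+    ```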
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class GlueGenStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionLoraLoaderMixin):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: AutoModel,
+ tokenizer: AutoTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ language_adapter: TranslatorNoLN = None,
+ tensor_norm: torch.Tensor = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ language_adapter=language_adapter,
+ tensor_norm=tensor_norm,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def load_language_adapter(
+ self,
+ model_path: str,
+ num_token: int,
+ dim: int,
+ dim_out: int,
+ tensor_norm: torch.Tensor,
+ mult: int = 2,
+ depth: int = 5,
+ ):
+ device = self._execution_device
+ self.tensor_norm = tensor_norm.to(device)
+ self.language_adapter = TranslatorNoLN(num_tok=num_token, dim=dim, dim_out=dim_out, mult=mult, depth=depth).to(
+ device
+ )
+ self.language_adapter.load_state_dict(torch.load(model_path))
+
+ def _adapt_language(self, prompt_embeds: torch.Tensor):
+ prompt_embeds = prompt_embeds / 3
+ prompt_embeds = self.language_adapter(prompt_embeds) * (self.tensor_norm / 2)
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ elif self.language_adapter is not None:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ # Run prompt language adapter
+ if self.language_adapter is not None:
+ prompt_embeds = self._adapt_language(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+ # Run negative prompt language adapter
+ if self.language_adapter is not None:
+ negative_prompt_embeds = self._adapt_language(negative_prompt_embeds)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+            w (`torch.Tensor`):
+                The guidance scale values for which to generate embedding vectors.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
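+    # In addition, guidance is treated as disabled when the UNet exposes `time_cond_proj_dim`
+    # (e.g. LCM-distilled UNets), since the guidance weight is then expected to be injected
+    # through `get_guidance_scale_embedding` rather than through a second, unconditional pass.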
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ clip_skip: Optional[int] = None,
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+
+ Examples:
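+            ```py
+            >>> # Illustrative sketch only: the pipeline class this `__call__` belongs to is defined
+            >>> # earlier in this file, so `pipe` stands for an already constructed instance of it.
+            >>> image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]
+            >>> image.save("astronaut.png")
+            ```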
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ # to deal with lora scaling and other possible forward hooks
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
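+                # Classifier-free guidance combines the two predictions as
+                # eps = eps_uncond + guidance_scale * (eps_text - eps_uncond),
+                # i.e. `guidance_scale` is the weight `w` described in the property section above.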
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/hd_painter.py b/diffusers/examples/community/hd_painter.py
new file mode 100644
index 0000000000000000000000000000000000000000..df41be9ef7b126e5bf60be2aca13555f478dc580
--- /dev/null
+++ b/diffusers/examples/community/hd_painter.py
@@ -0,0 +1,994 @@
+import math
+import numbers
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.models import AsymmetricAutoencoderKL, ImageProjection
+from diffusers.models.attention_processor import Attention, AttnProcessor
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
+ StableDiffusionInpaintPipeline,
+ retrieve_timesteps,
+)
+from diffusers.utils import deprecate
+
+
+class RASGAttnProcessor:
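+    """
+    Attention processor for HD-Painter's RASG guidance. It behaves like the default processor but
+    additionally caches the pre-softmax attention scores of its layer, so that the pipeline can later
+    backpropagate a mask-alignment score through them (see the RASG block of the denoising loop below).
+    """
+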
+ def __init__(self, mask, token_idx, scale_factor):
+        self.attention_scores = None  # Stores the last pre-softmax similarity matrix of this layer; each attention layer gets its own RASGAttnProcessor instance
+ self.mask = mask
+ self.token_idx = token_idx
+ self.scale_factor = scale_factor
+        self.mask_resolution = mask.shape[-1] * mask.shape[-2]  # 64 x 64 if the image is 512x512
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ scale: float = 1.0,
+ ) -> torch.Tensor:
+        # Same as the default AttnProcessor up until the point where the similarity matrix is saved
+        downscale_factor = self.mask_resolution // hidden_states.shape[1]
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ # Automatically recognize the resolution and save the attention similarity values
+ # We need to use the values before the softmax function, hence the rewritten get_attention_scores function.
+ if downscale_factor == self.scale_factor**2:
+ self.attention_scores = get_attention_scores(attn, query, key, attention_mask)
+ attention_probs = self.attention_scores.softmax(dim=-1)
+ attention_probs = attention_probs.to(query.dtype)
+ else:
+ attention_probs = attn.get_attention_scores(query, key, attention_mask) # Original code
+
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class PAIntAAttnProcessor:
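+    """
+    Attention processor for HD-Painter's PAIntA mechanism. It rescales the self-attention of masked
+    (inpainted) pixels towards known pixels by the prompt relevance of those known pixels, which is
+    estimated from the cross-attention of the parent transformer block.
+    """
+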
+ def __init__(self, transformer_block, mask, token_idx, do_classifier_free_guidance, scale_factors):
+ self.transformer_block = transformer_block # Stores the parent transformer block.
+ self.mask = mask
+ self.scale_factors = scale_factors
+ self.do_classifier_free_guidance = do_classifier_free_guidance
+ self.token_idx = token_idx
+ self.shape = mask.shape[2:]
+        self.mask_resolution = mask.shape[-1] * mask.shape[-2]  # 64 x 64
+ self.default_processor = AttnProcessor()
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ scale: float = 1.0,
+ ) -> torch.Tensor:
+ # Automatically recognize the resolution of the current attention layer and resize the masks accordingly
+        downscale_factor = self.mask_resolution // hidden_states.shape[1]
+
+ mask = None
+ for factor in self.scale_factors:
+ if downscale_factor == factor**2:
+ shape = (self.shape[0] // factor, self.shape[1] // factor)
+ mask = F.interpolate(self.mask, shape, mode="bicubic") # B, 1, H, W
+ break
+ if mask is None:
+ return self.default_processor(attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale)
+
+        # The PAIntA-specific processing starts here
+ residual = hidden_states
+ # Save the input hidden_states for later use
+ input_hidden_states = hidden_states
+
+ # ================================================== #
+ # =============== SELF ATTENTION 1 ================= #
+ # ================================================== #
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ # self_attention_probs = attn.get_attention_scores(query, key, attention_mask) # We can't use post-softmax attention scores in this case
+ self_attention_scores = get_attention_scores(
+ attn, query, key, attention_mask
+        )  # The custom function returns the pre-softmax attention scores
+ self_attention_probs = self_attention_scores.softmax(
+ dim=-1
+ ) # Manually compute the probabilities here, the scores will be reused in the second part of PAIntA
+ self_attention_probs = self_attention_probs.to(query.dtype)
+
+ hidden_states = torch.bmm(self_attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ # x = x + self.attn1(self.norm1(x))
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection: # So many residuals everywhere
+ hidden_states = hidden_states + residual
+
+ self_attention_output_hidden_states = hidden_states / attn.rescale_output_factor
+
+ # ================================================== #
+ # ============ BasicTransformerBlock =============== #
+ # ================================================== #
+        # We use a hack: the code from the BasicTransformerBlock that sits between the self- and cross-attention is run here.
+        # The alternative would have been to modify the BasicTransformerBlock itself and add this functionality there.
+        # Changing the BasicTransformerBlock seemed like the bigger intervention, so we use this hack instead.
+
+        # The self-attention block receives the normalized latents from the BasicTransformerBlock,
+        # but the residual added to its output is the non-normalized version.
+        # Therefore we un-normalize the input hidden states here.
+ unnormalized_input_hidden_states = (
+ input_hidden_states + self.transformer_block.norm1.bias
+ ) * self.transformer_block.norm1.weight
+
+        # TODO: return if necessary
+ # if self.use_ada_layer_norm_zero:
+ # attn_output = gate_msa.unsqueeze(1) * attn_output
+ # elif self.use_ada_layer_norm_single:
+ # attn_output = gate_msa * attn_output
+
+ transformer_hidden_states = self_attention_output_hidden_states + unnormalized_input_hidden_states
+ if transformer_hidden_states.ndim == 4:
+ transformer_hidden_states = transformer_hidden_states.squeeze(1)
+
+        # TODO: return if necessary
+ # 2.5 GLIGEN Control
+ # if gligen_kwargs is not None:
+ # transformer_hidden_states = self.fuser(transformer_hidden_states, gligen_kwargs["objs"])
+ # NOTE: we experimented with using GLIGEN and HDPainter together, the results were not that great
+
+ # 3. Cross-Attention
+ if self.transformer_block.use_ada_layer_norm:
+ # transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states, timestep)
+ raise NotImplementedError()
+ elif self.transformer_block.use_ada_layer_norm_zero or self.transformer_block.use_layer_norm:
+ transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states)
+ elif self.transformer_block.use_ada_layer_norm_single:
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ transformer_norm_hidden_states = transformer_hidden_states
+ elif self.transformer_block.use_ada_layer_norm_continuous:
+ # transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states, added_cond_kwargs["pooled_text_emb"])
+ raise NotImplementedError()
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.transformer_block.pos_embed is not None and self.transformer_block.use_ada_layer_norm_single is False:
+ transformer_norm_hidden_states = self.transformer_block.pos_embed(transformer_norm_hidden_states)
+
+ # ================================================== #
+ # ================= CROSS ATTENTION ================ #
+ # ================================================== #
+
+ # We do an initial pass of the CrossAttention up to obtaining the similarity matrix here.
+ # The similarity matrix is used to obtain scaling coefficients for the attention matrix of the self attention
+ # We reuse the previously computed self-attention matrix, and only repeat the steps after the softmax
+
+ cross_attention_input_hidden_states = (
+ transformer_norm_hidden_states # Renaming the variable for the sake of readability
+ )
+
+ # TODO: check if classifier_free_guidance is being used before splitting here
+ if self.do_classifier_free_guidance:
+ # Our scaling coefficients depend only on the conditional part, so we split the inputs
+ (
+ _cross_attention_input_hidden_states_unconditional,
+ cross_attention_input_hidden_states_conditional,
+ ) = cross_attention_input_hidden_states.chunk(2)
+
+ # Same split for the encoder_hidden_states i.e. the tokens
+            # Since the self-attention processors don't get the encoder states as input, we inject them into the processor at the beginning.
+ _encoder_hidden_states_unconditional, encoder_hidden_states_conditional = self.encoder_hidden_states.chunk(
+ 2
+ )
+ else:
+ cross_attention_input_hidden_states_conditional = cross_attention_input_hidden_states
+            encoder_hidden_states_conditional = self.encoder_hidden_states
+
+ # Rename the variables for the sake of readability
+ # The part below is the beginning of the __call__ function of the following CrossAttention layer
+ cross_attention_hidden_states = cross_attention_input_hidden_states_conditional
+ cross_attention_encoder_hidden_states = encoder_hidden_states_conditional
+
+ attn2 = self.transformer_block.attn2
+
+ if attn2.spatial_norm is not None:
+ cross_attention_hidden_states = attn2.spatial_norm(cross_attention_hidden_states, temb)
+
+ input_ndim = cross_attention_hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = cross_attention_hidden_states.shape
+ cross_attention_hidden_states = cross_attention_hidden_states.view(
+ batch_size, channel, height * width
+ ).transpose(1, 2)
+
+ (
+ batch_size,
+ sequence_length,
+ _,
+ ) = cross_attention_hidden_states.shape # It is definitely a cross attention, so no need for an if block
+ # TODO: change the attention_mask here
+ attention_mask = attn2.prepare_attention_mask(
+ None, sequence_length, batch_size
+ ) # I assume the attention mask is the same...
+
+ if attn2.group_norm is not None:
+ cross_attention_hidden_states = attn2.group_norm(cross_attention_hidden_states.transpose(1, 2)).transpose(
+ 1, 2
+ )
+
+ query2 = attn2.to_q(cross_attention_hidden_states)
+
+ if attn2.norm_cross:
+ cross_attention_encoder_hidden_states = attn2.norm_encoder_hidden_states(
+ cross_attention_encoder_hidden_states
+ )
+
+ key2 = attn2.to_k(cross_attention_encoder_hidden_states)
+ query2 = attn2.head_to_batch_dim(query2)
+ key2 = attn2.head_to_batch_dim(key2)
+
+ cross_attention_probs = attn2.get_attention_scores(query2, key2, attention_mask)
+
+ # CrossAttention ends here, the remaining part is not used
+
+ # ================================================== #
+ # ================ SELF ATTENTION 2 ================ #
+ # ================================================== #
+        # Repeat the self-attention computation, this time rescaling the saved scores with the PAIntA coefficients
+
+ mask = (mask > 0.5).to(self_attention_output_hidden_states.dtype)
+ m = mask.to(self_attention_output_hidden_states.device)
+ # m = rearrange(m, 'b c h w -> b (h w) c').contiguous()
+ m = m.permute(0, 2, 3, 1).reshape((m.shape[0], -1, m.shape[1])).contiguous() # B HW 1
+ m = torch.matmul(m, m.permute(0, 2, 1)) + (1 - m)
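+        # m is now a (B, HW, HW) matrix that equals 1 when the query pixel lies outside the mask or
+        # both pixels lie inside it; only (masked query, known key) pairs remain 0 and are rescaled
+        # by the prompt-relevance coefficients computed below.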
+
+ # # Compute scaling coefficients for the similarity matrix
+ # # Select the cross attention values for the correct tokens only!
+ # cross_attention_probs = cross_attention_probs.mean(dim = 0)
+ # cross_attention_probs = cross_attention_probs[:, self.token_idx].sum(dim=1)
+
+ # cross_attention_probs = cross_attention_probs.reshape(shape)
+ # gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(self_attention_output_hidden_states.device)
+ # cross_attention_probs = gaussian_smoothing(cross_attention_probs.unsqueeze(0))[0] # optional smoothing
+ # cross_attention_probs = cross_attention_probs.reshape(-1)
+ # cross_attention_probs = ((cross_attention_probs - torch.median(cross_attention_probs.ravel())) / torch.max(cross_attention_probs.ravel())).clip(0, 1)
+
+ # c = (1 - m) * cross_attention_probs.reshape(1, 1, -1) + m # PAIntA scaling coefficients
+
+ # Compute scaling coefficients for the similarity matrix
+ # Select the cross attention values for the correct tokens only!
+
+ batch_size, dims, channels = cross_attention_probs.shape
+ batch_size = batch_size // attn.heads
+ cross_attention_probs = cross_attention_probs.reshape((batch_size, attn.heads, dims, channels)) # B, D, HW, T
+
+ cross_attention_probs = cross_attention_probs.mean(dim=1) # B, HW, T
+ cross_attention_probs = cross_attention_probs[..., self.token_idx].sum(dim=-1) # B, HW
+ cross_attention_probs = cross_attention_probs.reshape((batch_size,) + shape) # , B, H, W
+
+ gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(
+ self_attention_output_hidden_states.device
+ )
+ cross_attention_probs = gaussian_smoothing(cross_attention_probs[:, None])[:, 0] # optional smoothing B, H, W
+
+ # Median normalization
+ cross_attention_probs = cross_attention_probs.reshape(batch_size, -1) # B, HW
+ cross_attention_probs = (
+ cross_attention_probs - cross_attention_probs.median(dim=-1, keepdim=True).values
+ ) / cross_attention_probs.max(dim=-1, keepdim=True).values
+ cross_attention_probs = cross_attention_probs.clip(0, 1)
+
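+        # PAIntA scaling coefficients: attention from masked queries to known keys is scaled by the
+        # key pixel's prompt relevance, while every other query/key pair keeps its original score (c = 1).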
+ c = (1 - m) * cross_attention_probs.reshape(batch_size, 1, -1) + m
+ c = c.repeat_interleave(attn.heads, 0) # BD, HW
+ if self.do_classifier_free_guidance:
+ c = torch.cat([c, c]) # 2BD, HW
+
+ # Rescaling the original self-attention matrix
+ self_attention_scores_rescaled = self_attention_scores * c
+ self_attention_probs_rescaled = self_attention_scores_rescaled.softmax(dim=-1)
+
+ # Continuing the self attention normally using the new matrix
+ hidden_states = torch.bmm(self_attention_probs_rescaled, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + input_hidden_states
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
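+    r"""
+    Inpainting pipeline implementing HD-Painter (PAIntA + RASG) on top of
+    [`StableDiffusionInpaintPipeline`]. Minimal usage sketch (illustrative only; the checkpoint name and
+    the way this class is imported depend on how the surrounding repository wires it up):
+
+    ```py
+    pipe = StableDiffusionHDPainterPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+    image = pipe(prompt, image=init_image, mask_image=mask, use_painta=True, use_rasg=True).images[0]
+    ```
+    """
+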
+ def get_tokenized_prompt(self, prompt):
+ out = self.tokenizer(prompt)
+ return [self.tokenizer.decode(x) for x in out["input_ids"]]
+
+ def init_attn_processors(
+ self,
+ mask,
+ token_idx,
+ use_painta=True,
+ use_rasg=True,
+ painta_scale_factors=[2, 4], # 64x64 -> [16x16, 32x32]
+ rasg_scale_factor=4, # 64x64 -> 16x16
+ self_attention_layer_name="attn1",
+ cross_attention_layer_name="attn2",
+ list_of_painta_layer_names=None,
+ list_of_rasg_layer_names=None,
+ ):
+ default_processor = AttnProcessor()
+        height, width = mask.shape[-2:]
+        height, width = height // self.vae_scale_factor, width // self.vae_scale_factor
+
+ painta_scale_factors = [x * self.vae_scale_factor for x in painta_scale_factors]
+ rasg_scale_factor = self.vae_scale_factor * rasg_scale_factor
+
+ attn_processors = {}
+ for x in self.unet.attn_processors:
+ if (list_of_painta_layer_names is None and self_attention_layer_name in x) or (
+ list_of_painta_layer_names is not None and x in list_of_painta_layer_names
+ ):
+ if use_painta:
+ transformer_block = self.unet.get_submodule(x.replace(".attn1.processor", ""))
+ attn_processors[x] = PAIntAAttnProcessor(
+ transformer_block, mask, token_idx, self.do_classifier_free_guidance, painta_scale_factors
+ )
+ else:
+ attn_processors[x] = default_processor
+ elif (list_of_rasg_layer_names is None and cross_attention_layer_name in x) or (
+ list_of_rasg_layer_names is not None and x in list_of_rasg_layer_names
+ ):
+ if use_rasg:
+ attn_processors[x] = RASGAttnProcessor(mask, token_idx, rasg_scale_factor)
+ else:
+ attn_processors[x] = default_processor
+
+ self.unet.set_attn_processor(attn_processors)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.Tensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ positive_prompt: Optional[str] = "",
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.01,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: int = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ use_painta=True,
+ use_rasg=True,
+ self_attention_layer_name=".attn1",
+ cross_attention_layer_name=".attn2",
+ painta_scale_factors=[2, 4], # 16 x 16 and 32 x 32
+ rasg_scale_factor=4, # 16x16 by default
+ list_of_painta_layer_names=None,
+ list_of_rasg_layer_names=None,
+ **kwargs,
+ ):
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # Keep the original prompt around (it is needed for token indexing later) before appending the positive prompt.
+ prompt_no_positives = prompt
+ if isinstance(prompt, list):
+ prompt = [x + positive_prompt for x in prompt]
+ else:
+ prompt = prompt + positive_prompt
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # assert batch_size == 1, "Does not work with batch size > 1 currently"
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None:
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
+ # 4. set timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )
+
+ # 7.5 Setting up HD-Painter
+
+ # Get the indices of the tokens to be modified by both RASG and PAIntA
+ token_idx = list(range(1, self.get_tokenized_prompt(prompt_no_positives).index("<|endoftext|>"))) + [
+ self.get_tokenized_prompt(prompt).index("<|endoftext|>")
+ ]
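+        # token_idx covers every content token of the user prompt (positions 1 .. EOS-1 of the prompt
+        # without the positive suffix) plus the EOS position of the extended prompt.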
+
+ # Setting up the attention processors
+ self.init_attn_processors(
+ mask_condition,
+ token_idx,
+ use_painta,
+ use_rasg,
+ painta_scale_factors=painta_scale_factors,
+ rasg_scale_factor=rasg_scale_factor,
+ self_attention_layer_name=self_attention_layer_name,
+ cross_attention_layer_name=cross_attention_layer_name,
+ list_of_painta_layer_names=list_of_painta_layer_names,
+ list_of_rasg_layer_names=list_of_rasg_layer_names,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ if use_rasg:
+ extra_step_kwargs["generator"] = None
+
+ # 9.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+ # 9.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ painta_active = True
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
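+                # PAIntA is only applied during the early, high-noise timesteps; once t drops below 500
+                # the processors are re-initialized with PAIntA disabled (RASG, if enabled, stays active).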
+ if t < 500 and painta_active:
+ self.init_attn_processors(
+ mask_condition,
+ token_idx,
+ False,
+ use_rasg,
+ painta_scale_factors=painta_scale_factors,
+ rasg_scale_factor=rasg_scale_factor,
+ self_attention_layer_name=self_attention_layer_name,
+ cross_attention_layer_name=cross_attention_layer_name,
+ list_of_painta_layer_names=list_of_painta_layer_names,
+ list_of_rasg_layer_names=list_of_rasg_layer_names,
+ )
+ painta_active = False
+
+ with torch.enable_grad():
+ self.unet.zero_grad()
+ latents = latents.detach()
+ latents.requires_grad = True
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ self.scheduler.latents = latents
+ self.encoder_hidden_states = prompt_embeds
+ for attn_processor in self.unet.attn_processors.values():
+ attn_processor.encoder_hidden_states = prompt_embeds
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if use_rasg:
+ # Perform RASG
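+                        # RASG: compute a score that rewards cross-attention mass inside the inpainting
+                        # mask, backpropagate it to the latents, and use the normalized gradient as the
+                        # variance noise of the scheduler step, steering sampling so the prompt tokens
+                        # attend to the masked region.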
+ _, _, height, width = mask_condition.shape # 512 x 512
+ scale_factor = self.vae_scale_factor * rasg_scale_factor # 8 * 4 = 32
+
+ # TODO: Fix for > 1 batch_size
+ rasg_mask = F.interpolate(
+ mask_condition, (height // scale_factor, width // scale_factor), mode="bicubic"
+                        )[0, 0]  # (B, 1, h, w) -> (h, w); assumes batch size 1 (see the TODO above)
+
+ # Aggregate the saved attention maps
+ attn_map = []
+ for processor in self.unet.attn_processors.values():
+ if hasattr(processor, "attention_scores") and processor.attention_scores is not None:
+ if self.do_classifier_free_guidance:
+ attn_map.append(processor.attention_scores.chunk(2)[1]) # (B/2) x H, 256, 77
+ else:
+ attn_map.append(processor.attention_scores) # B x H, 256, 77 ?
+
+ attn_map = (
+ torch.cat(attn_map)
+ .mean(0)
+ .permute(1, 0)
+ .reshape((-1, height // scale_factor, width // scale_factor))
+ ) # 77, 16, 16
+
+ # Compute the attention score
+ attn_score = -sum(
+ [
+ F.binary_cross_entropy_with_logits(x - 1.0, rasg_mask.to(device))
+ for x in attn_map[token_idx]
+ ]
+ )
+
+ # Backward the score and compute the gradients
+ attn_score.backward()
+
+                        # Normalize the gradients and compute the noise component
+ variance_noise = latents.grad.detach()
+ variance_noise -= torch.mean(variance_noise, [1, 2, 3], keepdim=True)
+ variance_noise /= torch.std(variance_noise, [1, 2, 3], keepdim=True)
+ else:
+ variance_noise = None
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False, variance_noise=variance_noise
+ )[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ condition_kwargs = {}
+ if isinstance(self.vae, AsymmetricAutoencoderKL):
+ init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
+ init_image_condition = init_image.clone()
+ init_image = self._encode_vae_image(init_image, generator=generator)
+ mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
+ condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
+ image = self.vae.decode(
+ latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
+ )[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if padding_mask_crop is not None:
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+# ============= Utility Functions ============== #
+
+
+class GaussianSmoothing(nn.Module):
+ """
+    Apply Gaussian smoothing on a
+    1d, 2d or 3d tensor. Filtering is performed separately for each channel
+ in the input using a depthwise convolution.
+ Arguments:
+ channels (int, sequence): Number of channels of the input tensors. Output will
+ have this number of channels as well.
+ kernel_size (int, sequence): Size of the gaussian kernel.
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
+ dim (int, optional): The number of dimensions of the data.
+ Default value is 2 (spatial).
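+
+    Example (illustrative, mirroring how the pipeline above uses it):
+        smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(device)
+        smoothed = smoothing(attention_map[:, None])[:, 0]  # attention_map: a (B, H, W) tensor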
+ """
+
+ def __init__(self, channels, kernel_size, sigma, dim=2):
+ super(GaussianSmoothing, self).__init__()
+ if isinstance(kernel_size, numbers.Number):
+ kernel_size = [kernel_size] * dim
+ if isinstance(sigma, numbers.Number):
+ sigma = [sigma] * dim
+
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+            # exp(-(x - mean)^2 / (2 * std^2)); the kernel is renormalized below
+            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer("weight", kernel)
+ self.groups = channels
+
+ if dim == 1:
+ self.conv = F.conv1d
+ elif dim == 2:
+ self.conv = F.conv2d
+ elif dim == 3:
+ self.conv = F.conv3d
+ else:
+ raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))
+
+ def forward(self, input):
+ """
+ Apply gaussian filter to input.
+ Arguments:
+ input (torch.Tensor): Input to apply gaussian filter on.
+ Returns:
+ filtered (torch.Tensor): Filtered output.
+ """
+ return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding="same")
+
+
+def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
+) -> torch.Tensor:
+ r"""
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+ Returns:
+        `torch.Tensor`: The pre-softmax attention scores.
+ """
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ return attention_scores
diff --git a/diffusers/examples/community/iadb.py b/diffusers/examples/community/iadb.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e9e8d89dc6724fd9deb6e9d9f2b757d4cfb776
--- /dev/null
+++ b/diffusers/examples/community/iadb.py
@@ -0,0 +1,149 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+class IADBScheduler(SchedulerMixin, ConfigMixin):
+ """
+ IADBScheduler is a scheduler for the Iterative α-(de)Blending denoising method. It is simple and minimalist.
+
+ For more details, see the original paper: https://arxiv.org/abs/2305.03486 and the blog post: https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html
+ """
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ x_alpha: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Predict the sample at the previous timestep by reversing the ODE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.Tensor`): direct output from learned diffusion model. It is the direction from x0 to x1.
+ timestep (`float`): current timestep in the diffusion chain.
+ x_alpha (`torch.Tensor`): x_alpha sample for the current timestep
+
+ Returns:
+ `torch.Tensor`: the sample at the previous timestep
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ alpha = timestep / self.num_inference_steps
+ alpha_next = (timestep + 1) / self.num_inference_steps
+
+ d = model_output
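+        # One explicit Euler step of the deterministic alpha-(de)blending ODE: move x_alpha along the
+        # predicted direction d (approximately x1 - x0) by the increment (alpha_next - alpha).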
+
+ x_alpha = x_alpha + (alpha_next - alpha) * d
+
+ return x_alpha
+
+ def set_timesteps(self, num_inference_steps: int):
+ self.num_inference_steps = num_inference_steps
+
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ alpha: torch.Tensor,
+ ) -> torch.Tensor:
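+        # Linear alpha-blending used as the IADB "noising" operation: alpha = 1 returns the clean
+        # sample unchanged, alpha = 0 returns pure noise.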
+ return original_samples * alpha + noise * (1 - alpha)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
+
+
+class IADBPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # Sample gaussian noise to begin loop
+ if isinstance(self.unet.config.sample_size, int):
+ image_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ self.unet.config.sample_size,
+ self.unet.config.sample_size,
+ )
+ else:
+ image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+ x_alpha = image.clone()
+ for t in self.progress_bar(range(num_inference_steps)):
+ alpha = t / num_inference_steps
+
+            # 1. predict the blending direction (the model output approximates x1 - x0)
+ model_output = self.unet(x_alpha, torch.tensor(alpha, device=x_alpha.device)).sample
+
+ # 2. step
+ x_alpha = self.scheduler.step(model_output, t, x_alpha)
+
+ image = (x_alpha * 0.5 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/examples/community/imagic_stable_diffusion.py b/diffusers/examples/community/imagic_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea55dd3837bdac5bf973354af4d317ae7c16183
--- /dev/null
+++ b/diffusers/examples/community/imagic_stable_diffusion.py
@@ -0,0 +1,470 @@
+"""
+modeled after the textual_inversion.py / train_dreambooth.py and the work
+of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
+"""
+
+import inspect
+import warnings
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import logging
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
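+    # Resize to a multiple of 32, convert the HWC uint8 image to a NCHW float tensor, and rescale
+    # pixel values from [0, 255] to [-1, 1], as expected by the VAE encoder.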
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+class ImagicStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for imagic image editing.
+ See paper here: https://arxiv.org/pdf/2210.09276.pdf
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def train(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.Tensor, PIL.Image.Image],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ generator: Optional[torch.Generator] = None,
+ embedding_learning_rate: float = 0.001,
+ diffusion_model_learning_rate: float = 2e-6,
+ text_embedding_optimization_steps: int = 500,
+ model_fine_tuning_optimization_steps: int = 1000,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ accelerator = Accelerator(
+ gradient_accumulation_steps=1,
+ mixed_precision="fp16",
+ )
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # Freeze vae and unet
+ self.vae.requires_grad_(False)
+ self.unet.requires_grad_(False)
+ self.text_encoder.requires_grad_(False)
+ self.unet.eval()
+ self.vae.eval()
+ self.text_encoder.eval()
+
+ if accelerator.is_main_process:
+ accelerator.init_trackers(
+ "imagic",
+ config={
+ "embedding_learning_rate": embedding_learning_rate,
+ "text_embedding_optimization_steps": text_embedding_optimization_steps,
+ },
+ )
+
+ # get text embeddings for prompt
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = torch.nn.Parameter(
+ self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
+ )
+ text_embeddings = text_embeddings.detach()
+ text_embeddings.requires_grad_()
+ text_embeddings_orig = text_embeddings.clone()
+
+ # Initialize the optimizer
+ optimizer = torch.optim.Adam(
+ [text_embeddings], # only optimize the embeddings
+ lr=embedding_learning_rate,
+ )
+
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ latents_dtype = text_embeddings.dtype
+ image = image.to(device=self.device, dtype=latents_dtype)
+ init_latent_image_dist = self.vae.encode(image).latent_dist
+ image_latents = init_latent_image_dist.sample(generator=generator)
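+ # 0.18215 is the latent scaling factor of the Stable Diffusion v1 VAE
+ # (`vae.config.scaling_factor`).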
+ image_latents = 0.18215 * image_latents
+
+ progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ global_step = 0
+
+ logger.info("First optimizing the text embedding to better reconstruct the init image")
+ for _ in range(text_embedding_optimization_steps):
+ with accelerator.accumulate(text_embeddings):
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(image_latents.shape).to(image_latents.device)
+ timesteps = torch.randint(1000, (1,), device=image_latents.device)
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
+
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
+
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+ accelerator.backward(loss)
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ accelerator.wait_for_everyone()
+
+ text_embeddings.requires_grad_(False)
+
+ # Now we fine tune the unet to better reconstruct the image
+ self.unet.requires_grad_(True)
+ self.unet.train()
+ optimizer = torch.optim.Adam(
+ self.unet.parameters(), # only optimize unet
+ lr=diffusion_model_learning_rate,
+ )
+ progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)
+
+ logger.info("Next fine tuning the entire model to better reconstruct the init image")
+ for _ in range(model_fine_tuning_optimization_steps):
+ with accelerator.accumulate(self.unet.parameters()):
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(image_latents.shape).to(image_latents.device)
+ timesteps = torch.randint(1000, (1,), device=image_latents.device)
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
+
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
+
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+ accelerator.backward(loss)
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ accelerator.wait_for_everyone()
+ self.text_embeddings_orig = text_embeddings_orig
+ self.text_embeddings = text_embeddings
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ alpha: float = 1.2,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ guidance_scale: float = 7.5,
+ eta: float = 0.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+ Args:
+ alpha (`float`, *optional*, defaults to 1.2):
+ The interpolation factor between the original and optimized text embeddings. A value closer to 0
+ will resemble the original input image.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if self.text_embeddings is None:
+ raise ValueError("Please run the pipe.train() before trying to generate an image.")
+ if self.text_embeddings_orig is None:
+ raise ValueError("Please run the pipe.train() before trying to generate an image.")
+
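+ # Interpolate between the original prompt embedding (alpha -> 1) and the embedding that was
+ # optimized to reconstruct the input image (alpha -> 0).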
+ text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens = [""]
+ max_length = self.tokenizer.model_max_length
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
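+
+
+# Minimal usage sketch (illustrative only): how this pipeline class is named and loaded in this
+# repository is not shown here, so the checkpoint id and the `custom_pipeline` loading call below
+# are assumptions rather than a documented API.
+#
+# import torch
+# from diffusers import DiffusionPipeline
+#
+# pipe = DiffusionPipeline.from_pretrained(
+#     "CompVis/stable-diffusion-v1-4", custom_pipeline="imagic_stable_diffusion"
+# ).to("cuda")
+# pipe.train("A photo of a bird spreading its wings", image=init_image)
+# result = pipe(alpha=1.2, guidance_scale=7.5, num_inference_steps=50)
+# result.images[0].save("edited.png")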
diff --git a/diffusers/examples/community/img2img_inpainting.py b/diffusers/examples/community/img2img_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dfb7a39155fc2052312fa03c13d56541add13ea
--- /dev/null
+++ b/diffusers/examples/community/img2img_inpainting.py
@@ -0,0 +1,437 @@
+import inspect
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
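+# Converts a PIL image/mask pair into tensors: `image` is scaled to [-1, 1] with shape
+# (1, 3, H, W), `mask` is binarized to {0, 1} with shape (1, 1, H, W), and `masked_image`
+# keeps only the pixels where the mask is black (mask < 0.5), i.e. the regions to preserve.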
+def prepare_mask_and_masked_image(image, mask):
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ return mask, masked_image
+
+
+def check_size(image, height, width):
+ if isinstance(image, PIL.Image.Image):
+ w, h = image.size
+ elif isinstance(image, torch.Tensor):
+ *_, h, w = image.shape
+
+ if h != height or w != width:
+ raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
+
+
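+# Pastes `inner_image` onto `image` at `paste_offset`, using the alpha channel of
+# `inner_image` as the paste mask, and returns the composited RGB image.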
+def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, int] = (0, 0)):
+ inner_image = inner_image.convert("RGBA")
+ image = image.convert("RGB")
+
+ image.paste(inner_image, paste_offset, inner_image)
+ image = image.convert("RGB")
+
+ return image
+
+
+class ImageToImageInpaintingPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image-to-image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.Tensor, PIL.Image.Image],
+ inner_image: Union[torch.Tensor, PIL.Image.Image],
+ mask_image: Union[torch.Tensor, PIL.Image.Image],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ inner_image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be overlaid onto `image`. Non-transparent
+ regions of `inner_image` must fit inside white pixels in `mask_image`. Expects four channels, with
+ the last channel representing the alpha channel, which will be used to blend `inner_image` with
+ `image`. If the alpha channel is missing, the image will be forcibly converted to RGBA.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # check if input sizes are correct
+ check_size(image, height, width)
+ check_size(inner_image, height, width)
+ check_size(mask_image, height, width)
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""]
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ num_channels_latents = self.vae.config.latent_channels
+ latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # overlay the inner image
+ image = overlay_inner_image(image, inner_image)
+
+ # prepare mask and masked_image
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+ mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
+ masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
+
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+ masked_image_latents = 0.18215 * masked_image_latents
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
+ masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
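+
+
+# Minimal usage sketch (illustrative only): the checkpoint id and the `custom_pipeline` loading
+# path are assumptions; `init_image`, `inner_image`, and `mask_image` are PIL images of the same
+# size, with `inner_image` carrying an alpha channel.
+#
+# import torch
+# from diffusers import DiffusionPipeline
+#
+# pipe = DiffusionPipeline.from_pretrained(
+#     "runwayml/stable-diffusion-inpainting", custom_pipeline="img2img_inpainting"
+# ).to("cuda")
+# result = pipe(
+#     prompt="a mecha robot sitting on a bench",
+#     image=init_image,
+#     inner_image=inner_image,
+#     mask_image=mask_image,
+# )
+# result.images[0].save("inpainted.png")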
diff --git a/diffusers/examples/community/instaflow_one_step.py b/diffusers/examples/community/instaflow_one_step.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fef02287186d37d87384f4e235eb2efa14bea21
--- /dev/null
+++ b/diffusers/examples/community/instaflow_one_step.py
@@ -0,0 +1,689 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
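+# Typical use after classifier-free guidance (sketch; `noise_pred_uncond` and `noise_pred_text`
+# are the unconditional/conditional halves of the CFG batch):
+# noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)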
+
+class InstaFlowPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Rectified Flow and Euler discretization.
+ This customized pipeline is based on StableDiffusionPipeline from the official Diffusers library (0.21.4)
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
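+ # Each VAE down block halves the spatial resolution, so the latent-to-pixel scale factor
+ # is 2 ** (num_blocks - 1); for Stable Diffusion this is 8.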
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
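+ # Adds a dictionary of UNet weight deltas (`dW_dict`, keyed like `unet.state_dict()`)
+ # onto the current UNet weights, scaled by `alpha`, then reloads the merged state dict.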
+ def merge_dW_to_unet(pipe, dW_dict, alpha=1.0):
+ _tmp_sd = pipe.unet.state_dict()
+ for key in dW_dict.keys():
+ _tmp_sd[key] += dW_dict[key] * alpha
+ pipe.unet.load_state_dict(_tmp_sd, strict=False)
+ return pipe
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+                A function that is called every `callback_steps` steps during inference with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+
+ Examples:
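+            A minimal usage sketch (assumes `pipe` is an already-loaded instance of this
+            pipeline; prompt and filename are illustrative only):
+
+            ```py
+            >>> image = pipe("a photo of an astronaut riding a horse").images[0]
+            >>> image.save("astronaut.png")
+            ```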
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # 4. Prepare timesteps
+ timesteps = [(1.0 - i / num_inference_steps) * 1000.0 for i in range(num_inference_steps)]
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ dt = 1.0 / num_inference_steps
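+        # (The loop below is a plain explicit Euler integrator: each iteration advances the
+        # latents by `latents + dt * v_pred`, so `num_inference_steps` steps of size
+        # dt = 1 / num_inference_steps integrate the predicted velocity field over a unit
+        # time interval; the timesteps above are that schedule scaled by 1000.)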
+
+ # 7. Denoising loop of Euler discretization from t = 0 to t = 1
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * t
+
+ v_pred = self.unet(latent_model_input, vec_t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ v_pred_neg, v_pred_text = v_pred.chunk(2)
+ v_pred = v_pred_neg + guidance_scale * (v_pred_text - v_pred_neg)
+
+ latents = latents + dt * v_pred
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/interpolate_stable_diffusion.py b/diffusers/examples/community/interpolate_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..52b2707f33f7eb7eddc7e6339b30d80b97cd13e5
--- /dev/null
+++ b/diffusers/examples/community/interpolate_stable_diffusion.py
@@ -0,0 +1,498 @@
+import inspect
+import time
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+ """helper function to spherically interpolate two arrays v1 v2"""
+
+    inputs_are_torch = False
+    if not isinstance(v0, np.ndarray):
+        inputs_are_torch = True
+        input_device = v0.device
+        v0 = v0.cpu().numpy()
+        v1 = v1.cpu().numpy()
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ v2 = (1 - t) * v0 + t * v1
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0 + s1 * v1
+
+ if inputs_are_torch:
+ v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
+
+
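+# Worked example (illustrative, not executed): for unit vectors v0 = [1, 0] and v1 = [0, 1],
+# slerp(0.5, v0, v1) is approximately [0.7071, 0.7071], a point on the unit circle, whereas
+# naive linear interpolation would give [0.5, 0.5] and shrink the norm. This is why the
+# initial noise is slerped rather than lerped in `walk` below.
+
+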
+class StableDiffusionWalkPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ text_embeddings: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*, defaults to `None`):
+ The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ text_embeddings (`torch.Tensor`, *optional*, defaults to `None`):
+ Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
+ `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
+ the supplied `prompt`.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if text_embeddings is None:
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ print(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+ else:
+ batch_size = text_embeddings.shape[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = self.tokenizer.model_max_length
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def embed_text(self, text):
+ """takes in text and turns it into text embeddings"""
+ text_input = self.tokenizer(
+ text,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ with torch.no_grad():
+ embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ return embed
+
+ def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
+ """Takes in random seed and returns corresponding noise vector"""
+ return torch.randn(
+ (1, self.unet.config.in_channels, height // 8, width // 8),
+ generator=torch.Generator(device=self.device).manual_seed(seed),
+ device=self.device,
+ dtype=dtype,
+ )
+
+ def walk(
+ self,
+ prompts: List[str],
+ seeds: List[int],
+ num_interpolation_steps: Optional[int] = 6,
+ output_dir: Optional[str] = "./dreams",
+ name: Optional[str] = None,
+ batch_size: Optional[int] = 1,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ guidance_scale: Optional[float] = 7.5,
+ num_inference_steps: Optional[int] = 50,
+ eta: Optional[float] = 0.0,
+ ) -> List[str]:
+ """
+ Walks through a series of prompts and seeds, interpolating between them and saving the results to disk.
+
+ Args:
+ prompts (`List[str]`):
+ List of prompts to generate images for.
+ seeds (`List[int]`):
+ List of seeds corresponding to provided prompts. Must be the same length as prompts.
+ num_interpolation_steps (`int`, *optional*, defaults to 6):
+ Number of interpolation steps to take between prompts.
+ output_dir (`str`, *optional*, defaults to `./dreams`):
+ Directory to save the generated images to.
+ name (`str`, *optional*, defaults to `None`):
+ Subdirectory of `output_dir` to save the generated images to. If `None`, the name will
+ be the current time.
+ batch_size (`int`, *optional*, defaults to 1):
+ Number of images to generate at once.
+ height (`int`, *optional*, defaults to 512):
+ Height of the generated images.
+ width (`int`, *optional*, defaults to 512):
+ Width of the generated images.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+
+ Returns:
+ `List[str]`: List of paths to the generated images.
+ """
+        if len(prompts) != len(seeds):
+            raise ValueError(
+                f"Number of prompts and seeds must be equal. Got {len(prompts)} prompts and {len(seeds)} seeds."
+ )
+
+ name = name or time.strftime("%Y%m%d-%H%M%S")
+ save_path = Path(output_dir) / name
+ save_path.mkdir(exist_ok=True, parents=True)
+
+ frame_idx = 0
+ frame_filepaths = []
+ for prompt_a, prompt_b, seed_a, seed_b in zip(prompts, prompts[1:], seeds, seeds[1:]):
+ # Embed Text
+ embed_a = self.embed_text(prompt_a)
+ embed_b = self.embed_text(prompt_b)
+
+ # Get Noise
+ noise_dtype = embed_a.dtype
+ noise_a = self.get_noise(seed_a, noise_dtype, height, width)
+ noise_b = self.get_noise(seed_b, noise_dtype, height, width)
+
+ noise_batch, embeds_batch = None, None
+ T = np.linspace(0.0, 1.0, num_interpolation_steps)
+ for i, t in enumerate(T):
+ noise = slerp(float(t), noise_a, noise_b)
+ embed = torch.lerp(embed_a, embed_b, t)
+
+ noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise], dim=0)
+ embeds_batch = embed if embeds_batch is None else torch.cat([embeds_batch, embed], dim=0)
+
+ batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0]
+ if batch_is_ready:
+ outputs = self(
+ latents=noise_batch,
+ text_embeddings=embeds_batch,
+ height=height,
+ width=width,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ num_inference_steps=num_inference_steps,
+ )
+ noise_batch, embeds_batch = None, None
+
+ for image in outputs["images"]:
+ frame_filepath = str(save_path / f"frame_{frame_idx:06d}.png")
+ image.save(frame_filepath)
+ frame_filepaths.append(frame_filepath)
+ frame_idx += 1
+ return frame_filepaths
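+
+
+# Minimal usage sketch (the checkpoint id below is only an example; any Stable Diffusion 1.x
+# checkpoint should work, and this module must be importable or loaded as a custom pipeline):
+#
+#   pipe = StableDiffusionWalkPipeline.from_pretrained(
+#       "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+#   ).to("cuda")
+#   frame_paths = pipe.walk(
+#       prompts=["a photo of a cat", "a photo of a dog"],
+#       seeds=[42, 1337],
+#       num_interpolation_steps=8,
+#   )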
diff --git a/diffusers/examples/community/ip_adapter_face_id.py b/diffusers/examples/community/ip_adapter_face_id.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7dc775eeee363ac7fcf10f04a5e4efb1ff2bf41
--- /dev/null
+++ b/diffusers/examples/community/ip_adapter_face_id.py
@@ -0,0 +1,1128 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+from safetensors import safe_open
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor,
+ AttnProcessor2_0,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+)
+from diffusers.models.embeddings import MultiIPAdapterImageProjection
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ _get_model_file,
+ deprecate,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class IPAdapterFullImageProjection(nn.Module):
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+ super().__init__()
+ from diffusers.models.attention import FeedForward
+
+ self.num_tokens = num_tokens
+ self.cross_attention_dim = cross_attention_dim
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ x = self.ff(image_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
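+        # x now has shape (batch, num_tokens, cross_attention_dim): the single image embedding
+        # has been projected into `num_tokens` context tokens for the IP-Adapter attention layers.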
+ return self.norm(x)
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
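+
+
+# Usage sketch: with classifier-free guidance,
+#   noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+#   noise_pred = rescale_noise_cfg(noise_pred, noise_text, guidance_rescale=0.7)
+# pulls the standard deviation of the guided prediction back toward that of the
+# text-conditional branch (the paper suggests a rescale factor around 0.7).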
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
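+
+
+# Sketch of the intended call pattern inside a pipeline's `__call__` (names are illustrative):
+#   timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, 50, device)
+# or, with a custom schedule:
+#   timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, timesteps=[999, 749, 499, 249])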
+
+
+class IPAdapterFaceIDStableDiffusionPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ IPAdapterMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_name, **kwargs):
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(model_file, map_location="cpu")
+ self._load_ip_adapter_weights(state_dict)
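+
+    # Usage sketch (the repo id and weight name below are assumptions about where the FaceID
+    # checkpoints are published, not something this file guarantees):
+    #   pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin")
+    #   pipeline.set_ip_adapter_scale(0.7)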
+
+ def convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+ updated_state_dict = {}
+ clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
+ clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
+ multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in
+ norm_layer = "norm.weight"
+ cross_attention_dim = state_dict[norm_layer].shape[0]
+ num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
+
+ image_projection = IPAdapterFullImageProjection(
+ cross_attention_dim=cross_attention_dim,
+ image_embed_dim=clip_embeddings_dim_in,
+ mult=multiplier,
+ num_tokens=num_tokens,
+ )
+
+ for key, value in state_dict.items():
+ diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+ diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+ updated_state_dict[diffusers_name] = value
+
+ image_projection.load_state_dict(updated_state_dict)
+ return image_projection
+
+ def _load_ip_adapter_weights(self, state_dict):
+ num_image_text_embeds = 4
+
+ self.unet.encoder_hid_proj = None
+
+ # set ip-adapter cross-attention processors & load state_dict
+ attn_procs = {}
+ lora_dict = {}
+ key_id = 0
+ for name in self.unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = self.unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = self.unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None or "motion_modules" in name:
+ attn_processor_class = (
+ AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+ )
+ attn_procs[name] = attn_processor_class()
+
+ lora_dict.update(
+ {f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
+ )
+ lora_dict.update(
+ {
+ f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_out_lora.down.weight"
+ ]
+ }
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
+ )
+ key_id += 1
+ else:
+ attn_processor_class = (
+ IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
+ )
+ attn_procs[name] = attn_processor_class(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=num_image_text_embeds,
+ ).to(dtype=self.dtype, device=self.device)
+
+ lora_dict.update(
+ {f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
+ )
+ lora_dict.update(
+ {
+ f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_out_lora.down.weight"
+ ]
+ }
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
+ )
+
+ value_dict = {}
+ value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+ value_dict.update({"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+ attn_procs[name].load_state_dict(value_dict)
+ key_id += 1
+
+ self.unet.set_attn_processor(attn_procs)
+
+ self.load_lora_weights(lora_dict, adapter_name="faceid")
+ self.set_adapters(["faceid"], adapter_weights=[1.0])
+
+ # convert IP-Adapter Image Projection layers to diffusers
+ image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
+ image_projection_layers = [image_projection.to(device=self.device, dtype=self.dtype)]
+
+ self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+ self.unet.config.encoder_hid_dim_type = "ip_image_proj"
+
+ def set_ip_adapter_scale(self, scale):
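+        # `scale` weights the face-embedding branch of the IP-Adapter attention relative to the
+        # text prompt: 1.0 (the value set at load time) applies it fully, 0.0 disables it.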
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ for attn_processor in unet.attn_processors.values():
+ if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+ attn_processor.scale = [scale]
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+            w (`torch.Tensor`):
+                Guidance scale values at which to generate the embedding vectors.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
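+        # Sinusoidal (Fourier) features of the scaled guidance weight: with w' = 1000 * w and
+        # freq_j = 10000^(-j / (half_dim - 1)), the embedding is
+        # [sin(w' * freq_0), ..., sin(w' * freq_{half_dim-1}), cos(w' * freq_0), ..., cos(w' * freq_{half_dim-1})],
+        # zero-padded by one column when `embedding_dim` is odd.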
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
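+    # Concretely, the denoising loop below computes
+    #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond),
+    # so `guidance_scale = 1` collapses to the text-conditioned prediction alone.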
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what not to include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when guidance is disabled (`guidance_scale <= 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` and `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. It is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ # to deal with lora scaling and other possible forward hooks
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
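+        # IP-Adapter style image conditioning: the image embeddings are repeated for each image generated
+        # per prompt, paired with all-zero embeddings as the unconditional branch for classifier-free
+        # guidance, and wrapped in a list as expected by the UNet's `added_cond_kwargs["image_embeds"]`.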
+ if image_embeds is not None:
+ image_embeds = torch.stack([image_embeds] * num_images_per_prompt, dim=0).to(
+ device=device, dtype=prompt_embeds.dtype
+ )
+ negative_image_embeds = torch.zeros_like(image_embeds)
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+ image_embeds = [image_embeds]
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else {}
+
+ # 6.2 Optionally get Guidance Scale Embedding
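+        # A non-None `time_cond_proj_dim` indicates a guidance-distilled UNet (LCM-style): classifier-free
+        # guidance is disabled (see `do_classifier_free_guidance`) and `guidance_scale - 1` is instead
+        # embedded and passed to the UNet through `timestep_cond`.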
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
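+                    # `rescale_noise_cfg` rescales the guided prediction toward the per-sample standard
+                    # deviation of the text-conditioned prediction, blending by `guidance_rescale`, which
+                    # counteracts the overexposure that plain CFG causes with zero-terminal-SNR schedules.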
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/kohya_hires_fix.py b/diffusers/examples/community/kohya_hires_fix.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e36f32b19a3c421c7e75e3d67d002661563e89f
--- /dev/null
+++ b/diffusers/examples/community/kohya_hires_fix.py
@@ -0,0 +1,468 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.configuration_utils import register_to_config
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class UNet2DConditionModelHighResFix(UNet2DConditionModel):
+ r"""
+ A conditional 2D UNet model that applies Kohya fix proposed for high resolution image generation.
+
+    This model inherits from [`UNet2DConditionModel`]. Check the superclass documentation to learn about all the parameters.
+
+ Parameters:
+ high_res_fix (`List[Dict]`, *optional*, defaults to `[{'timestep': 600, 'scale_factor': 0.5, 'block_num': 1}]`):
+            Enables the Kohya fix for high-resolution generation: while the current timestep is above `timestep`, the activation maps produced by the down block given by `block_num` are scaled by `scale_factor`.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(self, high_res_fix: List[Dict] = [{"timestep": 600, "scale_factor": 0.5, "block_num": 1}], **kwargs):
+ super().__init__(**kwargs)
+ if high_res_fix:
+ self.config.high_res_fix = sorted(high_res_fix, key=lambda x: x["timestep"], reverse=True)
+
+ @classmethod
+ def _resize(cls, sample, target=None, scale_factor=1, mode="bicubic"):
+ dtype = sample.dtype
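+        # bicubic interpolation is not supported for bfloat16 on every backend, so the sample is
+        # temporarily upcast to float32 and cast back to its original dtype below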
+ if dtype == torch.bfloat16:
+ sample = sample.to(torch.float32)
+
+ if target is not None:
+ if sample.shape[-2:] != target.shape[-2:]:
+ sample = nn.functional.interpolate(sample, size=target.shape[-2:], mode=mode, align_corners=False)
+ elif scale_factor != 1:
+ sample = nn.functional.interpolate(sample, scale_factor=scale_factor, mode=mode, align_corners=False)
+
+ return sample.to(dtype)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
+ """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
+ if class_emb is not None:
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ aug_emb = self.get_aug_embed(
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+ if self.config.addition_embed_type == "image_hint":
+ aug_emb, hint = aug_emb
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ encoder_hidden_states = self.process_encoder_hidden_states(
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+ if cross_attention_kwargs is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated"
+                " and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used"
+                " for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead.",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ for down_i, downsample_block in enumerate(self.down_blocks):
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ # kohya high res fix
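+            # While the current timestep is still above the configured `timestep` (i.e. early in the
+            # denoising trajectory, since timesteps descend) and we are at the configured `block_num`,
+            # the running activations are downscaled by `scale_factor`; the resolutions are matched
+            # again in the up blocks below.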
+ if self.config.high_res_fix:
+ for high_res_fix in self.config.high_res_fix:
+ if timestep > high_res_fix["timestep"] and down_i == high_res_fix["block_num"]:
+ sample = self.__class__._resize(sample, scale_factor=high_res_fix["scale_factor"])
+ break
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # up scaling of kohya high res fix
+ if self.config.high_res_fix is not None:
+ if res_samples[0].shape[-2:] != sample.shape[-2:]:
+ sample = self.__class__._resize(sample, target=res_samples[0])
+ res_samples_up_sampled = (res_samples[0],)
+ for res_sample in res_samples[1:]:
+ res_samples_up_sampled += (self.__class__._resize(res_sample, target=res_samples[0]),)
+ res_samples = res_samples_up_sampled
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
+
+ @classmethod
+ def from_unet(cls, unet: UNet2DConditionModel, high_res_fix: list):
+        config = dict(unet.config)
+ config["high_res_fix"] = high_res_fix
+ unet_high_res = cls(**config)
+ unet_high_res.load_state_dict(unet.state_dict())
+ unet_high_res.to(unet.dtype)
+ return unet_high_res
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+    >>> pipe = DiffusionPipeline.from_pretrained(
+    ...     "CompVis/stable-diffusion-v1-4",
+    ...     custom_pipeline="kohya_hires_fix",
+    ...     torch_dtype=torch.float16,
+    ...     high_res_fix=[{"timestep": 600, "scale_factor": 0.5, "block_num": 1}],
+    ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt, height=1000, width=1600).images[0]
+ ```
+"""
+
+
+class StableDiffusionHighResFixPipeline(StableDiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion with Kohya fix for high resolution generation.
+
+ This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods.
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ high_res_fix (`List[Dict]`, *optional*, defaults to `[{'timestep': 600, 'scale_factor': 0.5, 'block_num': 1}]`):
+            Enables the Kohya fix for high-resolution generation: while the current timestep is above `timestep`, the activation maps produced by the down block given by `block_num` are scaled by `scale_factor`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ high_res_fix: List[Dict] = [{"timestep": 600, "scale_factor": 0.5, "block_num": 1}],
+ ):
+ super().__init__(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ requires_safety_checker=requires_safety_checker,
+ )
+
+ unet = UNet2DConditionModelHighResFix.from_unet(unet=unet, high_res_fix=high_res_fix)
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/diffusers/examples/community/latent_consistency_img2img.py b/diffusers/examples/community/latent_consistency_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fe53ab6b8301ed2c7eb3a5cc503f97fc3141f51
--- /dev/null
+++ b/diffusers/examples/community/latent_consistency_img2img.py
@@ -0,0 +1,821 @@
+# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging
+from diffusers.configuration_utils import register_to_config
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.utils import BaseOutput
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
+ _optional_components = ["scheduler"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: "LCMSchedulerWithTimestamp",
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ scheduler = (
+ scheduler
+ if scheduler is not None
+ else LCMSchedulerWithTimestamp(
+ beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear", prediction_type="epsilon"
+ )
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ prompt_embeds: None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+
+        # The effective batch size is determined in `__call__`; nothing needs to be derived from
+        # `prompt` / `prompt_embeds` here before encoding.
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ latents=None,
+ generator=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ # batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+            deprecation_message = (
+                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+                " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
+                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+                " your script to pass as many initial images as text prompts to suppress this warning."
+            )
+            logger.warning(deprecation_message)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+ Args:
+            w: torch.Tensor: guidance scale values at which to generate the embedding vectors
+ embedding_dim: int: dimension of the embeddings to generate
+ dtype: data type of the generated embeddings
+ Returns:
+            embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
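+        # e.g. with num_inference_steps=4 and strength=0.8: init_timestep = min(int(4 * 0.8), 4) = 3 and
+        # t_start = 1, so the first scheduler timestep is skipped and denoising runs over the last 3 steps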
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ strength: float = 0.8,
+ height: Optional[int] = 768,
+ width: Optional[int] = 768,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ latents: Optional[torch.Tensor] = None,
+ num_inference_steps: int = 4,
+ lcm_origin_steps: int = 50,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ )
+
+ # 3.5 encode image
+ image = self.image_processor.preprocess(image)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps)
+ # timesteps = self.scheduler.timesteps
+ # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
+ timesteps = self.scheduler.timesteps
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        logger.debug("LCM img2img timesteps: %s", timesteps)
+
+ # 5. Prepare latent variable
+ num_channels_latents = self.unet.config.in_channels
+ if latents is None:
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ latents,
+ )
+ bs = batch_size * num_images_per_prompt
+
+ # 6. Get Guidance Scale Embedding
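+        # LCM distills classifier-free guidance into the model itself: the guidance scale is encoded as a
+        # sinusoidal embedding and fed through `timestep_cond`, so no separate unconditional forward pass
+        # (and no negative prompt) is needed during sampling.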
+ w = torch.tensor(guidance_scale).repeat(bs)
+ w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device=device, dtype=latents.dtype)
+
+ # 7. LCM MultiStep Sampling Loop:
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ ts = torch.full((bs,), t, device=device, dtype=torch.long)
+ latents = latents.to(prompt_embeds.dtype)
+
+ # model prediction (v-prediction, eps, x)
+ model_pred = self.unet(
+ latents,
+ ts,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, denoised = self.scheduler.step(model_pred, i, t, latents, return_dict=False)
+
+ # # call the callback, if provided
+ # if i == len(timesteps) - 1:
+ progress_bar.update()
+
+ denoised = denoised.to(prompt_embeds.dtype)
+ if not output_type == "latent":
+ image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = denoised
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class LCMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+ Args:
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.Tensor
+ denoised: Optional[torch.Tensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
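+    # Each beta comes from the ratio of consecutive cumulative products:
+    #     beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at `max_beta`.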
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+ Args:
+ betas (`torch.Tensor`):
+ the betas that the scheduler is being initialized with.
+ Returns:
+ `torch.Tensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):
+ """
+ This class modifies LCMScheduler so that `set_timesteps` additionally accepts a denoising-strength argument, allowing the timestep schedule to start from an intermediate point.
+
+
+ `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
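+ # DDIM-style variance: sigma_t^2 = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)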
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def set_timesteps(
+ self, strength, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ Args:
+ strength (`float`):
+ The denoising strength in `[0, 1]`; only the first `strength` fraction of the original LCM schedule is kept.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ lcm_origin_steps (`int`):
+ The number of timesteps in the original LCM training schedule from which the inference timesteps are sub-sampled.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # LCM Timesteps Setting: # Linear Spacing
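+ # `c` evenly spaces the `lcm_origin_steps` training timesteps over the full training range;
+ # `strength` truncates that schedule, then every `skipping_step`-th entry (starting from the last)
+ # is kept and the result is cut down to exactly `num_inference_steps` timesteps.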
+ c = self.config.num_train_timesteps // lcm_origin_steps
+ lcm_origin_timesteps = (
+ np.asarray(list(range(1, int(lcm_origin_steps * strength) + 1))) * c - 1
+ ) # LCM Training Steps Schedule
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
+
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
+
+ def get_scalings_for_boundary_condition_discrete(self, t):
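+ # Consistency-model boundary condition: c_skip(0) = 1 and c_out(0) = 0, so that f(x, 0) = x.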
+ self.sigma_data = 0.5 # Default: 0.5
+
+ # Dividing t by 0.1 makes c_skip decay quickly away from t = 0, so it is almost a delta function at t = 0.
+ c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
+ c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timeindex: int,
+ timestep: int,
+ sample: torch.Tensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[LCMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
+ Returns:
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # 1. get previous step value
+ prev_timeindex = timeindex + 1
+ if prev_timeindex < len(self.timesteps):
+ prev_timestep = self.timesteps[prev_timeindex]
+ else:
+ prev_timestep = timestep
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 3. Get scalings for boundary conditions
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+
+ # 4. Different Parameterization:
+ parameterization = self.config.prediction_type
+
+ if parameterization == "epsilon": # noise-prediction
+ pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
+
+ elif parameterization == "sample": # x-prediction
+ pred_x0 = model_output
+
+ elif parameterization == "v_prediction": # v-prediction
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
+
+ # 4. Denoise model output using boundary conditions
+ denoised = c_out * pred_x0 + c_skip * sample
+
+ # 6. Sample z ~ N(0, I) for multi-step inference.
+ # Noise is not used for one-step sampling.
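+ # Multi-step LCM re-noises the consistency estimate:
+ # x_prev = sqrt(alpha_bar_prev) * denoised + sqrt(1 - alpha_bar_prev) * z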
+ if len(self.timesteps) > 1:
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=model_output.device)
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
+ else:
+ prev_sample = denoised
+
+ if not return_dict:
+ return (prev_sample, denoised)
+
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
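+ # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps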
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
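+ # v-prediction target: v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0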
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/examples/community/latent_consistency_interpolate.py b/diffusers/examples/community/latent_consistency_interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..84adc125b1911981d6057ecade44578cd17fb817
--- /dev/null
+++ b/diffusers/examples/community/latent_consistency_interpolate.py
@@ -0,0 +1,999 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import LCMScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> import numpy as np
+
+ >>> from diffusers import DiffusionPipeline
+
+ >>> pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_interpolate")
+ >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+ >>> pipe.to(torch_device="cuda", torch_dtype=torch.float32)
+
+ >>> prompts = ["A cat", "A dog", "A horse"]
+ >>> num_inference_steps = 4
+ >>> num_interpolation_steps = 24
+ >>> seed = 1337
+
+ >>> torch.manual_seed(seed)
+ >>> np.random.seed(seed)
+
+ >>> images = pipe(
+ prompt=prompts,
+ height=512,
+ width=512,
+ num_inference_steps=num_inference_steps,
+ num_interpolation_steps=num_interpolation_steps,
+ guidance_scale=8.0,
+ embedding_interpolation_type="lerp",
+ latent_interpolation_type="slerp",
+ process_batch_size=4, # Make it higher or lower based on your GPU memory
+ generator=torch.Generator().manual_seed(seed),
+ )
+
+ >>> # Save the images as a video
+ >>> import imageio
+ >>> from PIL import Image
+
+ >>> def pil_to_video(images: List[Image.Image], filename: str, fps: int = 60) -> None:
+ frames = [np.array(image) for image in images]
+ with imageio.get_writer(filename, fps=fps) as video_writer:
+ for frame in frames:
+ video_writer.append_data(frame)
+
+ >>> pil_to_video(images, "lcm_interpolate.mp4", fps=24)
+ ```
+"""
+
+
+def lerp(
+ v0: Union[torch.Tensor, np.ndarray],
+ v1: Union[torch.Tensor, np.ndarray],
+ t: Union[float, torch.Tensor, np.ndarray],
+) -> Union[torch.Tensor, np.ndarray]:
+ """
+ Linearly interpolate between two vectors/tensors.
+
+ Args:
+ v0 (`torch.Tensor` or `np.ndarray`): First vector/tensor.
+ v1 (`torch.Tensor` or `np.ndarray`): Second vector/tensor.
+ t: (`float`, `torch.Tensor`, or `np.ndarray`):
+ Interpolation factor. If float, must be between 0 and 1. If np.ndarray or
+ torch.Tensor, must be one dimensional with values between 0 and 1.
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]
+ Interpolated vector/tensor between v0 and v1.
+ """
+ inputs_are_torch = False
+ t_is_float = False
+
+ if isinstance(v0, torch.Tensor):
+ inputs_are_torch = True
+ input_device = v0.device
+ v0 = v0.cpu().numpy()
+ v1 = v1.cpu().numpy()
+
+ if isinstance(t, torch.Tensor):
+ inputs_are_torch = True
+ input_device = t.device
+ t = t.cpu().numpy()
+ elif isinstance(t, float):
+ t_is_float = True
+ t = np.array([t])
+
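+ # Add a trailing axis to `t` and a leading axis to `v0`/`v1` so the interpolation broadcasts
+ # across all interpolation factors at once.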
+ t = t[..., None]
+ v0 = v0[None, ...]
+ v1 = v1[None, ...]
+ v2 = (1 - t) * v0 + t * v1
+
+ if t_is_float and v0.ndim > 1:
+ assert v2.shape[0] == 1
+ v2 = np.squeeze(v2, axis=0)
+ if inputs_are_torch:
+ v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
+
+
+def slerp(
+ v0: Union[torch.Tensor, np.ndarray],
+ v1: Union[torch.Tensor, np.ndarray],
+ t: Union[float, torch.Tensor, np.ndarray],
+ DOT_THRESHOLD=0.9995,
+) -> Union[torch.Tensor, np.ndarray]:
+ """
+ Spherical linear interpolation between two vectors/tensors.
+
+ Args:
+ v0 (`torch.Tensor` or `np.ndarray`): First vector/tensor.
+ v1 (`torch.Tensor` or `np.ndarray`): Second vector/tensor.
+ t: (`float`, `torch.Tensor`, or `np.ndarray`):
+ Interpolation factor. If float, must be between 0 and 1. If np.ndarray or
+ torch.Tensor, must be one dimensional with values between 0 and 1.
+ DOT_THRESHOLD (`float`, *optional*, default=0.9995):
+ Threshold for when to use linear interpolation instead of spherical interpolation.
+
+ Returns:
+ `torch.Tensor` or `np.ndarray`:
+ Interpolated vector/tensor between v0 and v1.
+ """
+ inputs_are_torch = False
+ t_is_float = False
+
+ if isinstance(v0, torch.Tensor):
+ inputs_are_torch = True
+ input_device = v0.device
+ v0 = v0.cpu().numpy()
+ v1 = v1.cpu().numpy()
+
+ if isinstance(t, torch.Tensor):
+ inputs_are_torch = True
+ input_device = t.device
+ t = t.cpu().numpy()
+ elif isinstance(t, float):
+ t_is_float = True
+ t = np.array([t], dtype=v0.dtype)
+
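+ # Cosine of the angle between the flattened vectors; when they are nearly parallel, spherical
+ # interpolation is numerically unstable, so fall back to linear interpolation.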
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ # v1 and v2 are close to parallel
+ # Use linear interpolation instead
+ v2 = lerp(v0, v1, t)
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ s0 = s0[..., None]
+ s1 = s1[..., None]
+ v0 = v0[None, ...]
+ v1 = v1[None, ...]
+ v2 = s0 * v0 + s1 * v1
+
+ if t_is_float and v0.ndim > 1:
+ assert v2.shape[0] == 1
+ v2 = np.squeeze(v2, axis=0)
+ if inputs_are_torch:
+ v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
+
+
+class LatentConsistencyModelWalkPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using a latent consistency model.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
+ supports [`LCMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ requires_safety_checker (`bool`, *optional*, defaults to `True`):
+ Whether the pipeline requires a safety checker component.
+ """
+
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: LCMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
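+ # Spatial downsampling factor of the VAE (2 per resolution level below the first), e.g. 8 for Stable Diffusion.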
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values (one per sample) to embed
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
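+ # Sinusoidal (transformer-style) embedding of the guidance scale `w`, scaled by 1000 before
+ # embedding, analogous to a timestep embedding.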
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Currently StableDiffusionPipeline.check_inputs with negative prompt stuff removed
+ def check_inputs(
+ self,
+ prompt: Union[str, List[str]],
+ height: int,
+ width: int,
+ callback_steps: int,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ @torch.no_grad()
+ def interpolate_embedding(
+ self,
+ start_embedding: torch.Tensor,
+ end_embedding: torch.Tensor,
+ num_interpolation_steps: Union[int, List[int]],
+ interpolation_type: str,
+ ) -> torch.Tensor:
+ if interpolation_type == "lerp":
+ interpolation_fn = lerp
+ elif interpolation_type == "slerp":
+ interpolation_fn = slerp
+ else:
+ raise ValueError(
+ f"embedding_interpolation_type must be one of ['lerp', 'slerp'], got {interpolation_type}."
+ )
+
+ embedding = torch.cat([start_embedding, end_embedding])
+ steps = torch.linspace(0, 1, num_interpolation_steps, dtype=embedding.dtype).cpu().numpy()
+ steps = np.expand_dims(steps, axis=tuple(range(1, embedding.ndim)))
+ interpolations = []
+
+ # Interpolate between text embeddings
+ # TODO(aryan): Think of a better way of doing this
+ # See if it can be done in parallel instead
+ for i in range(embedding.shape[0] - 1):
+ interpolations.append(interpolation_fn(embedding[i], embedding[i + 1], steps).squeeze(dim=1))
+
+ interpolations = torch.cat(interpolations)
+ return interpolations
+
+ @torch.no_grad()
+ def interpolate_latent(
+ self,
+ start_latent: torch.Tensor,
+ end_latent: torch.Tensor,
+ num_interpolation_steps: Union[int, List[int]],
+ interpolation_type: str,
+ ) -> torch.Tensor:
+ if interpolation_type == "lerp":
+ interpolation_fn = lerp
+ elif interpolation_type == "slerp":
+ interpolation_fn = slerp
+
+ latent = torch.cat([start_latent, end_latent])
+ steps = torch.linspace(0, 1, num_interpolation_steps, dtype=latent.dtype).cpu().numpy()
+ steps = np.expand_dims(steps, axis=tuple(range(1, latent.ndim)))
+ interpolations = []
+
+ # Interpolate between latents
+ # TODO: Think of a better way of doing this
+ # See if it can be done in parallel instead
+ for i in range(latent.shape[0] - 1):
+ interpolations.append(interpolation_fn(latent[i], latent[i + 1], steps).squeeze(dim=1))
+
+ return torch.cat(interpolations)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 4,
+ num_interpolation_steps: int = 8,
+ original_inference_steps: int = None,
+ guidance_scale: float = 8.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ embedding_interpolation_type: str = "lerp",
+ latent_interpolation_type: str = "slerp",
+ process_batch_size: int = 4,
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 4):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ num_interpolation_steps (`int`, *optional*, defaults to 8):
+ The number of interpolated images to generate between each pair of consecutive prompts.
+ original_inference_steps (`int`, *optional*):
+ The original number of inference steps use to generate a linearly-spaced timestep schedule, from which
+ we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule,
+ following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the
+ scheduler's `original_inference_steps` attribute.
+ guidance_scale (`float`, *optional*, defaults to 8.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ Note that the original latent consistency models paper uses a different CFG formulation where the
+ guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale >
+ 0`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ embedding_interpolation_type (`str`, *optional*, defaults to `"lerp"`):
+ The type of interpolation to use for interpolating between text embeddings. Choose between `"lerp"` and `"slerp"`.
+ latent_interpolation_type (`str`, *optional*, defaults to `"slerp"`):
+ The type of interpolation to use for interpolating between latents. Choose between `"lerp"` and `"slerp"`.
+ process_batch_size (`int`, *optional*, defaults to 4):
+ The batch size to use for processing the images. This is useful when generating a large number of images
+ and you want to avoid running out of memory.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps, prompt_embeds, callback_on_step_end_tensor_inputs)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ if batch_size < 2:
+ raise ValueError(f"`prompt` must have length of at least 2 but found {batch_size}")
+ if num_images_per_prompt != 1:
+ raise ValueError("`num_images_per_prompt` must be `1` as no other value is supported yet")
+ if prompt_embeds is not None:
+ raise ValueError("`prompt_embeds` must be None since it is not supported yet")
+ if latents is not None:
+ raise ValueError("`latents` must be None since it is not supported yet")
+
+ device = self._execution_device
+ # do_classifier_free_guidance = guidance_scale > 1.0
+
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ self.scheduler.set_timesteps(num_inference_steps, device, original_inference_steps=original_inference_steps)
+ timesteps = self.scheduler.timesteps
+ num_channels_latents = self.unet.config.in_channels
+ # bs = batch_size * num_images_per_prompt
+
+ # 3. Encode initial input prompt
+ prompt_embeds_1, _ = self.encode_prompt(
+ prompt[:1],
+ device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=False,
+ negative_prompt=None,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=None,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Prepare initial latent variables
+ latents_1 = self.prepare_latents(
+ 1,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds_1.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ images = []
+
+ # 5. Iterate over prompts and perform latent walk. Note that we do this two prompts at a time
+ # otherwise the memory usage ends up being too high.
+ with self.progress_bar(total=batch_size - 1) as prompt_progress_bar:
+ for i in range(1, batch_size):
+ # 6. Encode current prompt
+ prompt_embeds_2, _ = self.encode_prompt(
+ prompt[i : i + 1],
+ device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=False,
+ negative_prompt=None,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=None,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 7. Prepare current latent variables
+ latents_2 = self.prepare_latents(
+ 1,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds_2.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Interpolate between previous and current prompt embeddings and latents
+ inference_embeddings = self.interpolate_embedding(
+ start_embedding=prompt_embeds_1,
+ end_embedding=prompt_embeds_2,
+ num_interpolation_steps=num_interpolation_steps,
+ interpolation_type=embedding_interpolation_type,
+ )
+ inference_latents = self.interpolate_latent(
+ start_latent=latents_1,
+ end_latent=latents_2,
+ num_interpolation_steps=num_interpolation_steps,
+ interpolation_type=latent_interpolation_type,
+ )
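+ # Keep the final interpolant so the next prompt segment starts exactly where this one ends,
+ # without re-encoding the prompt or re-sampling the latent.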
+ next_prompt_embeds = inference_embeddings[-1:].detach().clone()
+ next_latents = inference_latents[-1:].detach().clone()
+ bs = num_interpolation_steps
+
+ # 9. Perform inference in batches. Note the use of `process_batch_size` to control the batch size
+ # of the inference. This is useful for reducing memory usage and can be configured based on the
+ # available GPU memory.
+ with self.progress_bar(
+ total=(bs + process_batch_size - 1) // process_batch_size
+ ) as batch_progress_bar:
+ for batch_index in range(0, bs, process_batch_size):
+ batch_inference_latents = inference_latents[batch_index : batch_index + process_batch_size]
+ batch_inference_embeddings = inference_embeddings[
+ batch_index : batch_index + process_batch_size
+ ]
+
+ self.scheduler.set_timesteps(
+ num_inference_steps, device, original_inference_steps=original_inference_steps
+ )
+ timesteps = self.scheduler.timesteps
+
+ current_bs = batch_inference_embeddings.shape[0]
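+ # LCM distils classifier-free guidance into the model: (guidance_scale - 1) is embedded and
+ # passed as `timestep_cond`, so no separate unconditional forward pass is needed.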
+ w = torch.tensor(self.guidance_scale - 1).repeat(current_bs)
+ w_embedding = self.get_guidance_scale_embedding(
+ w, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents_1.dtype)
+
+ # 10. Perform inference for current batch
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for index, t in enumerate(timesteps):
+ batch_inference_latents = batch_inference_latents.to(batch_inference_embeddings.dtype)
+
+ # model prediction (v-prediction, eps, x)
+ model_pred = self.unet(
+ batch_inference_latents,
+ t,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=batch_inference_embeddings,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ batch_inference_latents, denoised = self.scheduler.step(
+ model_pred, t, batch_inference_latents, **extra_step_kwargs, return_dict=False
+ )
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, index, t, callback_kwargs)
+
+ batch_inference_latents = callback_outputs.pop("latents", batch_inference_latents)
+ batch_inference_embeddings = callback_outputs.pop(
+ "prompt_embeds", batch_inference_embeddings
+ )
+ w_embedding = callback_outputs.pop("w_embedding", w_embedding)
+ denoised = callback_outputs.pop("denoised", denoised)
+
+ # call the callback, if provided
+ if index == len(timesteps) - 1 or (
+ (index + 1) > num_warmup_steps and (index + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+ if callback is not None and index % callback_steps == 0:
+ step_idx = index // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, batch_inference_latents)
+
+ denoised = denoised.to(batch_inference_embeddings.dtype)
+
+ # Note: This is not supported because you would get black images in your latent walk if
+ # NSFW concept is detected
+ # if not output_type == "latent":
+ # image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
+ # image, has_nsfw_concept = self.run_safety_checker(image, device, inference_embeddings.dtype)
+ # else:
+ # image = denoised
+ # has_nsfw_concept = None
+
+ # if has_nsfw_concept is None:
+ # do_denormalize = [True] * image.shape[0]
+ # else:
+ # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
+ do_denormalize = [True] * image.shape[0]
+ has_nsfw_concept = None
+
+ image = self.image_processor.postprocess(
+ image, output_type=output_type, do_denormalize=do_denormalize
+ )
+ images.append(image)
+
+ batch_progress_bar.update()
+
+ prompt_embeds_1 = next_prompt_embeds
+ latents_1 = next_latents
+
+ prompt_progress_bar.update()
+
+ # 11. Determine what should be returned
+ if output_type == "pil":
+ images = [image for image_list in images for image in image_list]
+ elif output_type == "np":
+ images = np.concatenate(images)
+ elif output_type == "pt":
+ images = torch.cat(images)
+ else:
+ raise ValueError("`output_type` must be one of 'pil', 'np' or 'pt'.")
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (images, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/latent_consistency_txt2img.py b/diffusers/examples/community/latent_consistency_txt2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f25a6db2722e9923a5a269385b166880ff73dce
--- /dev/null
+++ b/diffusers/examples/community/latent_consistency_txt2img.py
@@ -0,0 +1,729 @@
+# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging
+from diffusers.configuration_utils import register_to_config
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.utils import BaseOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class LatentConsistencyModelPipeline(DiffusionPipeline):
+ _optional_components = ["scheduler"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: "LCMScheduler",
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ scheduler = (
+ scheduler
+ if scheduler is not None
+ else LCMScheduler(
+ beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear", prediction_type="epsilon"
+ )
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+
+ if prompt is not None and isinstance(prompt, str):
+ pass
+ elif prompt is not None and isinstance(prompt, list):
+ len(prompt)
+ else:
+ prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if latents is None:
+ latents = torch.randn(shape, dtype=dtype).to(device)
+ else:
+ latents = latents.to(device)
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+ Args:
+ w (`torch.Tensor`): guidance scale values (1-D, one entry per sample) to embed
+ embedding_dim: int: dimension of the embeddings to generate
+ dtype: data type of the generated embeddings
+ Returns:
+ embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = 768,
+ width: Optional[int] = 768,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ latents: Optional[torch.Tensor] = None,
+ num_inference_steps: int = 4,
+ lcm_origin_steps: int = 50,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, lcm_origin_steps)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variable
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ latents,
+ )
+ bs = batch_size * num_images_per_prompt
+
+ # 6. Get Guidance Scale Embedding
+ w = torch.tensor(guidance_scale).repeat(bs)
+ w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device=device, dtype=latents.dtype)
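+ # The guidance scale enters the UNet through `timestep_cond` below, so no separate
+ # classifier-free guidance (unconditional) pass is required.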
+
+ # 7. LCM MultiStep Sampling Loop:
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ ts = torch.full((bs,), t, device=device, dtype=torch.long)
+ latents = latents.to(prompt_embeds.dtype)
+
+ # model prediction (v-prediction, eps, x)
+ model_pred = self.unet(
+ latents,
+ ts,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, denoised = self.scheduler.step(model_pred, i, t, latents, return_dict=False)
+
+ # # call the callback, if provided
+ # if i == len(timesteps) - 1:
+ progress_bar.update()
+
+ denoised = denoised.to(prompt_embeds.dtype)
+ if not output_type == "latent":
+ image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = denoised
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class LCMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+ Args:
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ denoised (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `denoised` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.Tensor
+ denoised: Optional[torch.Tensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
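+# For illustration: `betas_for_alpha_bar(1000)` returns a length-1000 tensor whose cumulative
+# product of (1 - beta) follows the cosine alpha_bar schedule defined above.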
+
+
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+ Args:
+ betas (`torch.Tensor`):
+ the betas that the scheduler is being initialized with.
+ Returns:
+ `torch.Tensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class LCMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `LCMScheduler` implements the multi-step sampling procedure of latent consistency models: at each step the
+ model predicts a denoised sample, which is then re-noised to the next timestep in the schedule.
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # LCM Timesteps Setting: # Linear Spacing
+ c = self.config.num_train_timesteps // lcm_origin_steps
+ lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
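+ # Worked example: with num_train_timesteps=1000, lcm_origin_steps=50 and num_inference_steps=4,
+ # c = 20, the training schedule is [19, 39, ..., 999], skipping_step = 12, and the resulting
+ # inference schedule is [999, 759, 519, 279].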
+
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
+
+ def get_scalings_for_boundary_condition_discrete(self, t):
+ self.sigma_data = 0.5 # Default: 0.5
+
+ # Dividing t by 0.1 makes c_skip behave almost like a delta function at t = 0.
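+ # These scalings satisfy the consistency-model boundary condition: as t -> 0, c_skip -> 1 and
+ # c_out -> 0, so the prediction collapses to the input sample at t = 0.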
+ c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
+ c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timeindex: int,
+ timestep: int,
+ sample: torch.Tensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[LCMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timeindex (`int`):
+ The index of the current timestep in the scheduler's `timesteps` schedule.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
+ Returns:
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # 1. get previous step value
+ prev_timeindex = timeindex + 1
+ if prev_timeindex < len(self.timesteps):
+ prev_timestep = self.timesteps[prev_timeindex]
+ else:
+ prev_timestep = timestep
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 3. Get scalings for boundary conditions
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+
+ # 4. Different Parameterization:
+ parameterization = self.config.prediction_type
+
+ if parameterization == "epsilon": # noise-prediction
+ pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
+
+ elif parameterization == "sample": # x-prediction
+ pred_x0 = model_output
+
+ elif parameterization == "v_prediction": # v-prediction
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
+
+ # 5. Denoise model output using boundary conditions
+ denoised = c_out * pred_x0 + c_skip * sample
+
+ # 6. Sample z ~ N(0, I) for multi-step inference
+ # Noise is not used for one-step sampling.
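+ # For multi-step sampling, the clean estimate `denoised` is re-noised to the noise level of
+ # the next timestep in the (decreasing) schedule, `prev_timestep`, before the next iteration.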
+ if len(self.timesteps) > 1:
+ noise = torch.randn(model_output.shape).to(model_output.device)
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
+ else:
+ prev_sample = denoised
+
+ if not return_dict:
+ return (prev_sample, denoised)
+
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/examples/community/llm_grounded_diffusion.py b/diffusers/examples/community/llm_grounded_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..49c07491135496ee0737672da7278280f0b6b4af
--- /dev/null
+++ b/diffusers/examples/community/llm_grounded_diffusion.py
@@ -0,0 +1,1563 @@
+# Copyright 2024 Long Lian, the GLIGEN Authors, and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This is a single file implementation of LMD+. See README.md for examples.
+
+import ast
+import gc
+import inspect
+import math
+import warnings
+from collections.abc import Iterable
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention import Attention, GatedSelfAttentionDense
+from diffusers.models.attention_processor import AttnProcessor2_0
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines import DiffusionPipeline
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "longlian/lmd_plus",
+ ... custom_pipeline="llm_grounded_diffusion",
+ ... custom_revision="main",
+ ... variant="fp16", torch_dtype=torch.float16
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # Generate an image described by the prompt and
+ >>> # insert objects described by text at the region defined by bounding boxes
+ >>> prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
+ >>> boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]
+ >>> phrases = ["a waterfall", "a modern high speed train"]
+
+ >>> images = pipe(
+ ... prompt=prompt,
+ ... phrases=phrases,
+ ... boxes=boxes,
+ ... gligen_scheduled_sampling_beta=0.4,
+ ... output_type="pil",
+ ... num_inference_steps=50,
+ ... lmd_guidance_kwargs={}
+ ... ).images
+
+ >>> images[0].save("./lmd_plus_generation.jpg")
+
+ >>> # Generate directly from a text prompt and an LLM response
+ >>> prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
+ >>> phrases, boxes, bg_prompt, neg_prompt = pipe.parse_llm_response(\"""
+[('a waterfall', [71, 105, 148, 258]), ('a modern high speed train', [255, 223, 181, 149])]
+Background prompt: A beautiful forest with fall foliage
+Negative prompt:
+\""")
+
+ >>> images = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=neg_prompt,
+ ... phrases=phrases,
+ ... boxes=boxes,
+ ... gligen_scheduled_sampling_beta=0.4,
+ ... output_type="pil",
+ ... num_inference_steps=50,
+ ... lmd_guidance_kwargs={}
+ ... ).images
+
+ >>> images[0].save("./lmd_plus_generation.jpg")
+
+
+ ```
+"""
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
+# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
+DEFAULT_GUIDANCE_ATTN_KEYS = [
+ ("mid", 0, 0, 0),
+ ("up", 1, 0, 0),
+ ("up", 1, 1, 0),
+ ("up", 1, 2, 0),
+]
+
+
+def convert_attn_keys(key):
+ """Convert the attention key from tuple format to the torch state format"""
+
+ if key[0] == "mid":
+ assert key[1] == 0, f"mid block only has one block but the index is {key[1]}"
+ return f"{key[0]}_block.attentions.{key[2]}.transformer_blocks.{key[3]}.attn2.processor"
+
+ return f"{key[0]}_blocks.{key[1]}.attentions.{key[2]}.transformer_blocks.{key[3]}.attn2.processor"
+
+
+DEFAULT_GUIDANCE_ATTN_KEYS = [convert_attn_keys(key) for key in DEFAULT_GUIDANCE_ATTN_KEYS]
+
+
+def scale_proportion(obj_box, H, W):
+ # Separately rounding box_w and box_h to allow shift invariant box sizes. Otherwise box sizes may change when both coordinates being rounded end with ".5".
+ x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H)
+ box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round((obj_box[3] - obj_box[1]) * H)
+ x_max, y_max = x_min + box_w, y_min + box_h
+
+ x_min, y_min = max(x_min, 0), max(y_min, 0)
+ x_max, y_max = min(x_max, W), min(y_max, H)
+
+ return x_min, y_min, x_max, y_max
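+# Example: scale_proportion((0.25, 0.25, 0.75, 0.75), H=64, W=64) returns (16, 16, 48, 48).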
+
+
+# Adapted from the parent class `AttnProcessor2_0`
+class AttnProcessorWithHook(AttnProcessor2_0):
+ def __init__(
+ self,
+ attn_processor_key,
+ hidden_size,
+ cross_attention_dim,
+ hook=None,
+ fast_attn=True,
+ enabled=True,
+ ):
+ super().__init__()
+ self.attn_processor_key = attn_processor_key
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.hook = hook
+ self.fast_attn = fast_attn
+ self.enabled = enabled
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ scale: float = 1.0,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+ query = attn.to_q(hidden_states, *args)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ if (self.hook is not None and self.enabled) or not self.fast_attn:
+ query_batch_dim = attn.head_to_batch_dim(query)
+ key_batch_dim = attn.head_to_batch_dim(key)
+ value_batch_dim = attn.head_to_batch_dim(value)
+ attention_probs = attn.get_attention_scores(query_batch_dim, key_batch_dim, attention_mask)
+
+ if self.hook is not None and self.enabled:
+ # Call the hook with query, key, value, and attention maps
+ self.hook(
+ self.attn_processor_key,
+ query_batch_dim,
+ key_batch_dim,
+ value_batch_dim,
+ attention_probs,
+ )
+
+ if self.fast_attn:
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attention_mask is not None:
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states, *args)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LLMGroundedDiffusionPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ IPAdapterMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf.
+
+ This model inherits from [`StableDiffusionPipeline`] and aims at implementing the pipeline with minimal modifications. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ This is a simplified implementation that does not perform latent or attention transfer from single object generation to overall generation. The final image is generated directly with attention and adapters control.
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ requires_safety_checker (bool):
+ Whether a safety checker is needed for this pipeline.
+ """
+
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ objects_text = "Objects: "
+ bg_prompt_text = "Background prompt: "
+ bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
+ neg_prompt_text = "Negative prompt: "
+ neg_prompt_text_no_trailing_space = neg_prompt_text.rstrip()
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ # This is copied from StableDiffusionPipeline, with hook initializations for LMD+.
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Initialize the attention hooks for LLM-grounded Diffusion
+ self.register_attn_hooks(unet)
+ self._saved_attn = None
+
+ def attn_hook(self, name, query, key, value, attention_probs):
+ if name in DEFAULT_GUIDANCE_ATTN_KEYS:
+ self._saved_attn[name] = attention_probs
+
+ @classmethod
+ def convert_box(cls, box, height, width):
+ # box: x, y, w, h (in 512 format) -> x_min, y_min, x_max, y_max
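+ # e.g., convert_box([71, 105, 148, 258], height=512, width=512) gives approximately
+ # (0.1387, 0.2051, 0.4277, 0.7090), matching the example boxes in EXAMPLE_DOC_STRING.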
+ x_min, y_min = box[0] / width, box[1] / height
+ w_box, h_box = box[2] / width, box[3] / height
+
+ x_max, y_max = x_min + w_box, y_min + h_box
+
+ return x_min, y_min, x_max, y_max
+
+ @classmethod
+ def _parse_response_with_negative(cls, text):
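+ # Example: for the response
+ # "[('a waterfall', [71, 105, 148, 258])]\nBackground prompt: A forest\nNegative prompt: None"
+ # this returns ([('a waterfall', [71, 105, 148, 258])], 'A forest', '').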
+ if not text:
+ raise ValueError("LLM response is empty")
+
+ if cls.objects_text in text:
+ text = text.split(cls.objects_text)[1]
+
+ text_split = text.split(cls.bg_prompt_text_no_trailing_space)
+ if len(text_split) == 2:
+ gen_boxes, text_rem = text_split
+ else:
+ raise ValueError(f"LLM response is incomplete: {text}")
+
+ text_split = text_rem.split(cls.neg_prompt_text_no_trailing_space)
+
+ if len(text_split) == 2:
+ bg_prompt, neg_prompt = text_split
+ else:
+ raise ValueError(f"LLM response is incomplete: {text}")
+
+ try:
+ gen_boxes = ast.literal_eval(gen_boxes)
+ except SyntaxError as e:
+ # Sometimes the response is in plain text
+ if "No objects" in gen_boxes or gen_boxes.strip() == "":
+ gen_boxes = []
+ else:
+ raise e
+ bg_prompt = bg_prompt.strip()
+ neg_prompt = neg_prompt.strip()
+
+ # LLM may return "None" to mean no negative prompt provided.
+ if neg_prompt == "None":
+ neg_prompt = ""
+
+ return gen_boxes, bg_prompt, neg_prompt
+
+ @classmethod
+ def parse_llm_response(cls, response, canvas_height=512, canvas_width=512):
+ # Infer from spec
+ gen_boxes, bg_prompt, neg_prompt = cls._parse_response_with_negative(text=response)
+
+ gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0])
+
+ phrases = [name for name, _ in gen_boxes]
+ boxes = [cls.convert_box(box, height=canvas_height, width=canvas_width) for _, box in gen_boxes]
+
+ return phrases, boxes, bg_prompt, neg_prompt
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ phrases,
+ boxes,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ phrase_indices=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt is None and phrase_indices is None:
+ raise ValueError("If the prompt is None, the phrase_indices cannot be None")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if len(phrases) != len(boxes):
+ raise ValueError(
+ "length of `phrases` and `boxes` has to be same, but"
+ f" got: `phrases` {len(phrases)} != `boxes` {len(boxes)}"
+ )
+
+ def register_attn_hooks(self, unet):
+ """Registering hooks to obtain the attention maps for guidance"""
+
+ attn_procs = {}
+
+ for name in unet.attn_processors.keys():
+ # Only obtain the queries and keys from cross-attention
+ if name.endswith("attn1.processor") or name.endswith("fuser.attn.processor"):
+ # Keep the same attn_processors for self-attention (no hooks for self-attention)
+ attn_procs[name] = unet.attn_processors[name]
+ continue
+
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ attn_procs[name] = AttnProcessorWithHook(
+ attn_processor_key=name,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ hook=self.attn_hook,
+ fast_attn=True,
+ # Not enabled by default
+ enabled=False,
+ )
+
+ unet.set_attn_processor(attn_procs)
+
+ def enable_fuser(self, enabled=True):
+ for module in self.unet.modules():
+ if isinstance(module, GatedSelfAttentionDense):
+ module.enabled = enabled
+
+ def enable_attn_hook(self, enabled=True):
+ for module in self.unet.attn_processors.values():
+ if isinstance(module, AttnProcessorWithHook):
+ module.enabled = enabled
+
+ def get_token_map(self, prompt, padding="do_not_pad", verbose=False):
+ """Get a list of mapping: prompt index to str (prompt in a list of token str)"""
+ fg_prompt_tokens = self.tokenizer([prompt], padding=padding, max_length=77, return_tensors="np")
+ input_ids = fg_prompt_tokens["input_ids"][0]
+
+ token_map = []
+ for ind, item in enumerate(input_ids.tolist()):
+ token = self.tokenizer._convert_id_to_token(item)
+
+ if verbose:
+ logger.info(f"{ind}, {token} ({item})")
+
+ token_map.append(token)
+
+ return token_map
+
+ def get_phrase_indices(
+ self,
+ prompt,
+ phrases,
+ token_map=None,
+ add_suffix_if_not_found=False,
+ verbose=False,
+ ):
+ for obj in phrases:
+ # Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
+ if obj not in prompt:
+ prompt += "| " + obj
+
+ if token_map is None:
+ # We allow using a pre-computed token map.
+ token_map = self.get_token_map(prompt=prompt, padding="do_not_pad", verbose=verbose)
+ token_map_str = " ".join(token_map)
+
+ phrase_indices = []
+
+ for obj in phrases:
+ phrase_token_map = self.get_token_map(prompt=obj, padding="do_not_pad", verbose=verbose)
+ # Remove the begin-of-text and end-of-text tokens from the substring
+ phrase_token_map = phrase_token_map[1:-1]
+ phrase_token_map_len = len(phrase_token_map)
+ phrase_token_map_str = " ".join(phrase_token_map)
+
+ if verbose:
+ logger.info(
+ f"Full str: {token_map_str}, Substr: {phrase_token_map_str}, Phrase: {phrases}"
+ )
+
+ # Count the number of tokens before the substring; the `- 1` in the index drops the
+ # space that precedes the matched substring so the prefix splits cleanly on " ".
+ obj_first_index = len(token_map_str[: token_map_str.index(phrase_token_map_str) - 1].split(" "))
+
+ obj_position = list(range(obj_first_index, obj_first_index + phrase_token_map_len))
+ phrase_indices.append(obj_position)
+
+ if add_suffix_if_not_found:
+ return phrase_indices, prompt
+
+ return phrase_indices
+
+ def add_ca_loss_per_attn_map_to_loss(
+ self,
+ loss,
+ attn_map,
+ object_number,
+ bboxes,
+ phrase_indices,
+ fg_top_p=0.2,
+ bg_top_p=0.2,
+ fg_weight=1.0,
+ bg_weight=1.0,
+ ):
+ # b is the number of heads, not batch
+ b, i, j = attn_map.shape
+ H = W = int(math.sqrt(i))
+ for obj_idx in range(object_number):
+ obj_loss = 0
+ mask = torch.zeros(size=(H, W), device="cuda")
+ obj_boxes = bboxes[obj_idx]
+
+ # We support two level (one box per phrase) and three level (multiple boxes per phrase)
+ if not isinstance(obj_boxes[0], Iterable):
+ obj_boxes = [obj_boxes]
+
+ for obj_box in obj_boxes:
+ # x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
+ x_min, y_min, x_max, y_max = scale_proportion(obj_box, H=H, W=W)
+ mask[y_min:y_max, x_min:x_max] = 1
+
+ for obj_position in phrase_indices[obj_idx]:
+ # Could potentially optimize to compute this for loop in batch.
+ # Could crop the ref cross attention before saving to save memory.
+
+ # shape: (b, H * W)
+ ca_map_obj = attn_map[:, :, obj_position]
+ k_fg = (mask.sum() * fg_top_p).long().clamp_(min=1)
+ k_bg = ((1 - mask).sum() * bg_top_p).long().clamp_(min=1)
+
+ mask_1d = mask.view(1, -1)
+
+ # Max-based loss function
+
+ # Take the topk over spatial dimension, and then take the sum over heads dim
+ # The mean is over k_fg and k_bg dimension, so we don't need to sum and divide on our own.
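+ # The foreground term pushes the top-k attention values inside the box toward 1, while the
+ # background term pushes the top-k attention values outside the box toward 0.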
+ obj_loss += (1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1)).sum(dim=0) * fg_weight
+ obj_loss += ((ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1)).sum(dim=0) * bg_weight
+
+ loss += obj_loss / len(phrase_indices[obj_idx])
+
+ return loss
+
+ def compute_ca_loss(
+ self,
+ saved_attn,
+ bboxes,
+ phrase_indices,
+ guidance_attn_keys,
+ verbose=False,
+ **kwargs,
+ ):
+ """
+ The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing the loss.
+ `AttnProcessor` will put attention maps into the `save_attn_to_dict`.
+
+ `index` is the timestep.
+ `ref_ca_word_token_only`: This has precedence over `ref_ca_last_token_only` (i.e., if both are enabled, we take the token from word rather than the last token).
+ `ref_ca_last_token_only`: `ref_ca_saved_attn` comes from the attention map of the last token of the phrase in single object generation, so we apply it only to the last token of the phrase in overall generation if this is set to True. If set to False, `ref_ca_saved_attn` will be applied to all the text tokens.
+ """
+ loss = torch.tensor(0).float().cuda()
+ object_number = len(bboxes)
+ if object_number == 0:
+ return loss
+
+ for attn_key in guidance_attn_keys:
+ # We only have 1 cross attention for mid.
+
+ attn_map_integrated = saved_attn[attn_key]
+ if not attn_map_integrated.is_cuda:
+ attn_map_integrated = attn_map_integrated.cuda()
+ # Example dimension: [20, 64, 77]
+ attn_map = attn_map_integrated.squeeze(dim=0)
+
+ loss = self.add_ca_loss_per_attn_map_to_loss(
+ loss, attn_map, object_number, bboxes, phrase_indices, **kwargs
+ )
+
+ num_attn = len(guidance_attn_keys)
+
+ if num_attn > 0:
+ loss = loss / (object_number * num_attn)
+
+ return loss
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ gligen_scheduled_sampling_beta: float = 0.3,
+ phrases: List[str] = None,
+ boxes: List[List[float]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ lmd_guidance_kwargs: Optional[Dict[str, Any]] = {},
+ phrase_indices: Optional[List[int]] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ phrases (`List[str]`):
+ The phrases to guide what to include in each of the regions defined by the corresponding
+ `boxes`. There should only be one phrase per bounding box.
+ boxes (`List[List[float]]`):
+ The bounding boxes that identify rectangular regions of the image that are going to be filled with the
+ content described by the corresponding `phrases`. Each rectangular box is defined as a
+ `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1].
+ gligen_scheduled_sampling_beta (`float`, defaults to 0.3):
+ Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image
+ Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for
+ scheduled sampling during inference for improved quality and controllability.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ lmd_guidance_kwargs (`dict`, *optional*):
+ A kwargs dictionary that, if specified, is passed along to the `latent_lmd_guidance` function. Useful keys
+ include `loss_scale` (the guidance strength), `loss_threshold` (guidance is no longer applied once the loss
+ drops below this value), `max_iter` (the number of guidance iterations per step), and `guidance_timesteps`
+ (the number of diffusion timesteps to apply guidance on). See `latent_lmd_guidance` for implementation details.
+ phrase_indices (`list` of `list`, *optional*):
+ The indices of the tokens of each phrase in the overall prompt. If omitted, the pipeline matches the first
+ token subsequence of each phrase and appends any missing phrases to the end of the prompt.
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ phrases,
+ boxes,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ phrase_indices,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ if phrase_indices is None:
+ phrase_indices, prompt = self.get_phrase_indices(prompt, phrases, add_suffix_if_not_found=True)
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ if phrase_indices is None:
+ phrase_indices = []
+ prompt_parsed = []
+ for prompt_item in prompt:
+ (
+ phrase_indices_parsed_item,
+ prompt_parsed_item,
+ ) = self.get_phrase_indices(prompt_item, add_suffix_if_not_found=True)
+ phrase_indices.append(phrase_indices_parsed_item)
+ prompt_parsed.append(prompt_parsed_item)
+ prompt = prompt_parsed
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clip_skip=clip_skip,
+ )
+
+ cond_prompt_embeds = prompt_embeds
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None:
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ if do_classifier_free_guidance: # use the local flag; `self._guidance_scale` is not set in this `__call__`
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5.1 Prepare GLIGEN variables
+ max_objs = 30
+ if len(boxes) > max_objs:
+ warnings.warn(
+ f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.",
+ FutureWarning,
+ )
+ phrases = phrases[:max_objs]
+ boxes = boxes[:max_objs]
+
+ n_objs = len(boxes)
+ if n_objs:
+ # prepare batched input to the PositionNet (boxes, phrases, mask)
+ # Get tokens for phrases from pre-trained CLIPTokenizer
+ tokenizer_inputs = self.tokenizer(phrases, padding=True, return_tensors="pt").to(device)
+ # We use the same pre-trained text encoder to obtain
+ # the text features of the phrase tokens
+ _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output
+
+ # Each entity described in `phrases` is denoted by a bounding box;
+ # we represent the location information as (xmin, ymin, xmax, ymax)
+ cond_boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype)
+ if n_objs:
+ cond_boxes[:n_objs] = torch.tensor(boxes)
+ text_embeddings = torch.zeros(
+ max_objs,
+ self.unet.config.cross_attention_dim,
+ device=device,
+ dtype=self.text_encoder.dtype,
+ )
+ if n_objs:
+ text_embeddings[:n_objs] = _text_embeddings
+ # Generate a mask for each object that is an entity described by `phrases`
+ masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype)
+ masks[:n_objs] = 1
+
+ repeat_batch = batch_size * num_images_per_prompt
+ cond_boxes = cond_boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
+ if do_classifier_free_guidance:
+ repeat_batch = repeat_batch * 2
+ cond_boxes = torch.cat([cond_boxes] * 2)
+ text_embeddings = torch.cat([text_embeddings] * 2)
+ masks = torch.cat([masks] * 2)
+ masks[: repeat_batch // 2] = 0
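+ # At this point cond_boxes has shape (repeat_batch, max_objs, 4), text_embeddings has shape
+ # (repeat_batch, max_objs, cross_attention_dim), and masks has shape (repeat_batch, max_objs),
+ # with the unconditional half of the batch fully masked out.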
+ if cross_attention_kwargs is None:
+ cross_attention_kwargs = {}
+ cross_attention_kwargs["gligen"] = {
+ "boxes": cond_boxes,
+ "positive_embeddings": text_embeddings,
+ "masks": masks,
+ }
+
+ num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
+ self.enable_fuser(True)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+ loss_attn = torch.tensor(10000.0)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Scheduled sampling
+ if i == num_grounding_steps:
+ self.enable_fuser(False)
+
+ if latents.shape[1] != 4:
+ latents = torch.randn_like(latents[:, :4])
+
+ # 7.1 Perform LMD guidance
+ if boxes:
+ latents, loss_attn = self.latent_lmd_guidance(
+ cond_prompt_embeds,
+ index=i,
+ boxes=boxes,
+ phrase_indices=phrase_indices,
+ t=t,
+ latents=latents,
+ loss=loss_attn,
+ **lmd_guidance_kwargs,
+ )
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ @torch.set_grad_enabled(True)
+ def latent_lmd_guidance(
+ self,
+ cond_embeddings,
+ index,
+ boxes,
+ phrase_indices,
+ t,
+ latents,
+ loss,
+ *,
+ loss_scale=20,
+ loss_threshold=5.0,
+ max_iter=[3] * 5 + [2] * 5 + [1] * 5,
+ guidance_timesteps=15,
+ cross_attention_kwargs=None,
+ guidance_attn_keys=DEFAULT_GUIDANCE_ATTN_KEYS,
+ verbose=False,
+ clear_cache=False,
+ unet_additional_kwargs={},
+ guidance_callback=None,
+ **kwargs,
+ ):
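+ # Note on the defaults: `max_iter` may be a per-timestep schedule; [3] * 5 + [2] * 5 + [1] * 5 runs
+ # 3 guidance iterations for each of the first 5 guided timesteps, 2 for the next 5, and 1 for the last 5,
+ # matching `guidance_timesteps=15`. Guidance also stops early once loss / loss_scale <= loss_threshold.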
+ scheduler, unet = self.scheduler, self.unet
+
+ iteration = 0
+
+ if index < guidance_timesteps:
+ if isinstance(max_iter, list):
+ max_iter = max_iter[index]
+
+ if verbose:
+ logger.info(
+ f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
+ )
+
+ try:
+ self.enable_attn_hook(enabled=True)
+
+ while (
+ loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < guidance_timesteps
+ ):
+ self._saved_attn = {}
+
+ latents.requires_grad_(True)
+ latent_model_input = latents
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+ unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=cond_embeddings,
+ cross_attention_kwargs=cross_attention_kwargs,
+ **unet_additional_kwargs,
+ )
+
+ # update latents with guidance
+ loss = (
+ self.compute_ca_loss(
+ saved_attn=self._saved_attn,
+ bboxes=boxes,
+ phrase_indices=phrase_indices,
+ guidance_attn_keys=guidance_attn_keys,
+ verbose=verbose,
+ **kwargs,
+ )
+ * loss_scale
+ )
+
+ if torch.isnan(loss):
+ raise RuntimeError("**Loss is NaN**")
+
+ # This callback allows visualizations.
+ if guidance_callback is not None:
+ guidance_callback(self, latents, loss, iteration, index)
+
+ self._saved_attn = None
+
+ grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
+
+ latents.requires_grad_(False)
+
+ # Scaling with classifier guidance
+ alpha_prod_t = scheduler.alphas_cumprod[t]
+ # Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf
+ # DDIM: https://arxiv.org/pdf/2010.02502.pdf
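+ # i.e. latents <- latents - sqrt(1 - alpha_bar_t) * d(loss) / d(latents)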
+ scale = (1 - alpha_prod_t) ** (0.5)
+ latents = latents - scale * grad_cond
+
+ iteration += 1
+
+ if clear_cache:
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if verbose:
+ logger.info(
+ f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
+ )
+
+ finally:
+ self.enable_attn_hook(enabled=False)
+
+ return latents, loss
+
+ # Below are methods copied from StableDiffusionPipeline
+ # The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+ # Access `hidden_states` first, which contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
+ def num_timesteps(self):
+ return self._num_timesteps
diff --git a/diffusers/examples/community/lpw_stable_diffusion.py b/diffusers/examples/community/lpw_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec27acdce3318fd8cd27363d8cf452a5100d24e3
--- /dev/null
+++ b/diffusers/examples/community/lpw_stable_diffusion.py
@@ -0,0 +1,1427 @@
+import inspect
+import re
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+# ------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+re_attention = re.compile(
+ r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""",
+ re.X,
+)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \\( - literal character '('
+ \\[ - literal character '['
+ \\) - literal character ')'
+ \\] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\\(literal\\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith("\\"):
+ res.append([text[1:], 1.0])
+ elif text == "(":
+ round_brackets.append(len(res))
+ elif text == "[":
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ")" and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == "]" and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
+ return res
+
+
+def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
+ r"""
+ Tokenize a list of prompts and return their tokens together with the weight of each token.
+
+ No padding, starting or ending token is included.
+ """
+ tokens = []
+ weights = []
+ truncated = False
+ for text in prompt:
+ texts_and_weights = parse_prompt_attention(text)
+ text_token = []
+ text_weight = []
+ for word, weight in texts_and_weights:
+ # tokenize and discard the starting and the ending token
+ token = pipe.tokenizer(word).input_ids[1:-1]
+ text_token += token
+ # copy the weight by length of token
+ text_weight += [weight] * len(token)
+ # stop if the text is too long (longer than truncation limit)
+ if len(text_token) > max_length:
+ truncated = True
+ break
+ # truncate
+ if len(text_token) > max_length:
+ truncated = True
+ text_token = text_token[:max_length]
+ text_weight = text_weight[:max_length]
+ tokens.append(text_token)
+ weights.append(text_weight)
+ if truncated:
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+ return tokens, weights
+
+
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
+ r"""
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+ """
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+ for i in range(len(tokens)):
+ tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
+ if no_boseos_middle:
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+ else:
+ w = []
+ if len(weights[i]) == 0:
+ w = [1.0] * weights_length
+ else:
+ for j in range(max_embeddings_multiples):
+ w.append(1.0) # weight for starting token in this chunk
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+ w.append(1.0) # weight for ending token in this chunk
+ w += [1.0] * (weights_length - len(w))
+ weights[i] = w[:]
+
+ return tokens, weights
+
+
+def get_unweighted_text_embeddings(
+ pipe: DiffusionPipeline,
+ text_input: torch.Tensor,
+ chunk_length: int,
+ no_boseos_middle: Optional[bool] = True,
+ clip_skip: Optional[int] = None,
+):
+ """
+ When the length of the token sequence exceeds the capacity of the text encoder,
+ it is split into chunks that are sent to the text encoder individually.
+ """
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
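+ # e.g. with chunk_length=77 and a padded input of length (77 - 2) * 3 + 2 = 227, this yields 3 chunks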
+ if max_embeddings_multiples > 1:
+ text_embeddings = []
+ for i in range(max_embeddings_multiples):
+ # extract the i-th chunk
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+ # cover the head and the tail by the starting and the ending tokens
+ text_input_chunk[:, 0] = text_input[0, 0]
+ text_input_chunk[:, -1] = text_input[0, -1]
+ if clip_skip is None:
+ prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
+ text_embedding = prompt_embeds[0]
+ else:
+ prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
+ # Access `hidden_states` first, which contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if no_boseos_middle:
+ if i == 0:
+ # discard the ending token
+ text_embedding = text_embedding[:, :-1]
+ elif i == max_embeddings_multiples - 1:
+ # discard the starting token
+ text_embedding = text_embedding[:, 1:]
+ else:
+ # discard both starting and ending tokens
+ text_embedding = text_embedding[:, 1:-1]
+
+ text_embeddings.append(text_embedding)
+ text_embeddings = torch.concat(text_embeddings, axis=1)
+ else:
+ if clip_skip is None:
+ clip_skip = 0
+ prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
+ text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
+ return text_embeddings
+
+
+def get_weighted_text_embeddings(
+ pipe: DiffusionPipeline,
+ prompt: Union[str, List[str]],
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ no_boseos_middle: Optional[bool] = False,
+ skip_parsing: Optional[bool] = False,
+ skip_weighting: Optional[bool] = False,
+ clip_skip=None,
+ lora_scale=None,
+):
+ r"""
+ Prompts can be assigned with local weights using brackets. For example,
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
+
+ Args:
+ pipe (`DiffusionPipeline`):
+ Pipe to provide access to the tokenizer and the text encoder.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ uncond_prompt (`str` or `List[str]`):
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
+ is provided, the embeddings of `prompt` and `uncond_prompt` are concatenated.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
+ If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep the
+ starting and ending tokens in each of the middle chunks.
+ skip_parsing (`bool`, *optional*, defaults to `False`):
+ Skip the parsing of brackets.
+ skip_weighting (`bool`, *optional*, defaults to `False`):
+ Skip the weighting. When the parsing is skipped, it is forced True.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
+ pipe._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(pipe.text_encoder, lora_scale)
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ if not skip_parsing:
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+ else:
+ prompt_tokens = [
+ token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
+ ]
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens = [
+ token[1:-1]
+ for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+ ]
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
+ max_length = max([len(token) for token in prompt_tokens])
+ if uncond_prompt is not None:
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+ max_embeddings_multiples = min(
+ max_embeddings_multiples,
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
+ )
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+ # pad the length of tokens and weights
+ bos = pipe.tokenizer.bos_token_id
+ eos = pipe.tokenizer.eos_token_id
+ pad = getattr(pipe.tokenizer, "pad_token_id", eos)
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
+ prompt_tokens,
+ prompt_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
+ if uncond_prompt is not None:
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
+ uncond_tokens,
+ uncond_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
+
+ # get the embeddings
+ text_embeddings = get_unweighted_text_embeddings(
+ pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
+ )
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
+ if uncond_prompt is not None:
+ uncond_embeddings = get_unweighted_text_embeddings(
+ pipe,
+ uncond_tokens,
+ pipe.tokenizer.model_max_length,
+ no_boseos_middle=no_boseos_middle,
+ clip_skip=clip_skip,
+ )
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
+
+ # assign weights to the prompts and normalize in the sense of mean
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
+ if (not skip_parsing) and (not skip_weighting):
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+ text_embeddings *= prompt_weights.unsqueeze(-1)
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+ if uncond_prompt is not None:
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+ if pipe.text_encoder is not None:
+ if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(pipe.text_encoder, lora_scale)
+
+ if uncond_prompt is not None:
+ return text_embeddings, uncond_embeddings
+ return text_embeddings, None
+
+
+def preprocess_image(image, batch_size):
+ w, h = image.size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask, batch_size, scale_factor=8):
+ if not isinstance(mask, torch.Tensor):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = np.vstack([mask[None]] * batch_size)
+ mask = 1 - mask # repaint white, keep black
+ mask = torch.from_numpy(mask)
+ return mask
+
+ else:
+ valid_mask_channel_sizes = [1, 3]
+ # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
+ if mask.shape[3] in valid_mask_channel_sizes:
+ mask = mask.permute(0, 3, 1, 2)
+ elif mask.shape[1] not in valid_mask_channel_sizes:
+ raise ValueError(
+ f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
+ f" but received mask of shape {tuple(mask.shape)}"
+ )
+ # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
+ mask = mask.mean(dim=1, keepdim=True)
+ h, w = mask.shape[-2:]
+ h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
+ mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
+ return mask
+
+
+class StableDiffusionLongPromptWeightingPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
+ parsing weighting in the prompt.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder-->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(
+ requires_safety_checker=requires_safety_checker,
+ )
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ max_embeddings_multiples=3,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ clip_skip: Optional[int] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if negative_prompt_embeds is None:
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * batch_size
+ if batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ if prompt_embeds is None or negative_prompt_embeds is None:
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
+
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
+ pipe=self,
+ prompt=prompt,
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+ max_embeddings_multiples=max_embeddings_multiples,
+ clip_skip=clip_skip,
+ lora_scale=lora_scale,
+ )
+ if prompt_embeds is None:
+ prompt_embeds = prompt_embeds1
+ if negative_prompt_embeds is None:
+ negative_prompt_embeds = negative_prompt_embeds1
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
+ if is_text2img:
+ return self.scheduler.timesteps.to(device), num_inference_steps
+ else:
+ # get the original timestep using init_timestep
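+ # Illustrative example (comment added for clarity, values are hypothetical): with
+ # num_inference_steps=50 and strength=0.8, init_timestep = 40 and t_start = 10, so the
+ # last 40 scheduler timesteps are used for denoising.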
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ num_images_per_prompt,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if image is None:
+ batch_size = batch_size * num_images_per_prompt
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents, None, None
+ else:
+ image = image.to(device=self.device, dtype=dtype)
+ init_latent_dist = self.vae.encode(image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ # Expand init_latents for batch_size and num_images_per_prompt
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+ init_latents_orig = init_latents
+
+ # add noise to latents using the timesteps
+ noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+ return latents, init_latents_orig, noise
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ strength: float = 0.8,
+ num_images_per_prompt: Optional[int] = 1,
+ add_predicted_noise: Optional[bool] = False,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ clip_skip: Optional[int] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ mask_image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ add_predicted_noise (`bool`, *optional*, defaults to False):
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
+ the reverse diffusion process.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
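+
+ Example:
+ An illustrative sketch (not part of the original docstring); the checkpoint id and the
+ `custom_pipeline` name are placeholders and assume this file is loaded as a diffusers community pipeline.
+ >>> from diffusers import DiffusionPipeline
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", custom_pipeline="lpw_stable_diffusion"
+ ... )
+ >>> image = pipe(prompt="a (very beautiful:1.3) landscape, masterpiece", num_inference_steps=30).images[0]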
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ max_embeddings_multiples,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clip_skip=clip_skip,
+ lora_scale=lora_scale,
+ )
+ dtype = prompt_embeds.dtype
+
+ # 4. Preprocess image and mask
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess_image(image, batch_size)
+ if image is not None:
+ image = image.to(device=self.device, dtype=dtype)
+ if isinstance(mask_image, PIL.Image.Image):
+ mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
+ if mask_image is not None:
+ mask = mask_image.to(device=self.device, dtype=dtype)
+ mask = torch.cat([mask] * num_images_per_prompt)
+ else:
+ mask = None
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents, init_latents_orig, noise = self.prepare_latents(
+ image,
+ latent_timestep,
+ num_images_per_prompt,
+ batch_size,
+ self.unet.config.in_channels,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if mask is not None:
+ # masking
+ if add_predicted_noise:
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_orig, noise_pred_uncond, torch.tensor([t])
+ )
+ else:
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if i % callback_steps == 0:
+ if callback is not None:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+ if is_cancelled_callback is not None and is_cancelled_callback():
+ return None
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 11. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return image, has_nsfw_concept
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def text2img(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ clip_skip: Optional[int] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function for text-to-image generation.
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
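+
+ Example (illustrative, not from the original docstring; assumes `pipe` is an instance of this pipeline):
+ >>> output = pipe.text2img("a (white:1.2) cat sitting on a bench", width=512, height=512)
+ >>> image = output.images[0]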
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ is_cancelled_callback=is_cancelled_callback,
+ clip_skip=clip_skip,
+ callback_steps=callback_steps,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ def img2img(
+ self,
+ image: Union[torch.Tensor, PIL.Image.Image],
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function for image-to-image generation.
+ Args:
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
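+
+ Example (illustrative, not from the original docstring; `pipe` is an instance of this pipeline and
+ `init_image` is a `PIL.Image.Image` prepared by the caller):
+ >>> output = pipe.img2img(image=init_image, prompt="a (watercolor:1.2) painting of the scene", strength=0.6)
+ >>> image = output.images[0]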
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ strength=strength,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ is_cancelled_callback=is_cancelled_callback,
+ callback_steps=callback_steps,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ def inpaint(
+ self,
+ image: Union[torch.Tensor, PIL.Image.Image],
+ mask_image: Union[torch.Tensor, PIL.Image.Image],
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ add_predicted_noise: Optional[bool] = False,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function for inpaint.
+ Args:
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ add_predicted_noise (`bool`, *optional*, defaults to False):
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
+ the reverse diffusion process.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
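+
+ Example (illustrative, not from the original docstring; `pipe` is an instance of this pipeline, and
+ `init_image` and `mask` are `PIL.Image.Image` objects of the same size, with white marking the area to repaint):
+ >>> output = pipe.inpaint(image=init_image, mask_image=mask, prompt="a (red:1.3) sports car", strength=0.75)
+ >>> image = output.images[0]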
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ strength=strength,
+ num_images_per_prompt=num_images_per_prompt,
+ add_predicted_noise=add_predicted_noise,
+ eta=eta,
+ generator=generator,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ is_cancelled_callback=is_cancelled_callback,
+ callback_steps=callback_steps,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
diff --git a/diffusers/examples/community/lpw_stable_diffusion_onnx.py b/diffusers/examples/community/lpw_stable_diffusion_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..87c2944dbc44a063b33c5a9be9ff7bbdbcb77544
--- /dev/null
+++ b/diffusers/examples/community/lpw_stable_diffusion_onnx.py
@@ -0,0 +1,1148 @@
+import inspect
+import re
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTokenizer
+
+import diffusers
+from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import logging
+
+
+try:
+ from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
+except ImportError:
+ ORT_TO_NP_TYPE = {
+ "tensor(bool)": np.bool_,
+ "tensor(int8)": np.int8,
+ "tensor(uint8)": np.uint8,
+ "tensor(int16)": np.int16,
+ "tensor(uint16)": np.uint16,
+ "tensor(int32)": np.int32,
+ "tensor(uint32)": np.uint32,
+ "tensor(int64)": np.int64,
+ "tensor(uint64)": np.uint64,
+ "tensor(float16)": np.float16,
+ "tensor(float)": np.float32,
+ "tensor(double)": np.float64,
+ }
+
+try:
+ from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+ else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+re_attention = re.compile(
+ r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""",
+ re.X,
+)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \\( - literal character '('
+ \\[ - literal character '['
+ \\) - literal character ')'
+ \\] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\\(literal\\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith("\\"):
+ res.append([text[1:], 1.0])
+ elif text == "(":
+ round_brackets.append(len(res))
+ elif text == "[":
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ")" and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == "]" and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
+ return res
+
+
+def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
+ r"""
+ Tokenize a list of prompts and return their tokens together with the weight of each token.
+
+ No padding, starting or ending token is included.
+ """
+ tokens = []
+ weights = []
+ truncated = False
+ for text in prompt:
+ texts_and_weights = parse_prompt_attention(text)
+ text_token = []
+ text_weight = []
+ for word, weight in texts_and_weights:
+ # tokenize and discard the starting and the ending token
+ token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
+ text_token += list(token)
+ # copy the weight by length of token
+ text_weight += [weight] * len(token)
+ # stop if the text is too long (longer than truncation limit)
+ if len(text_token) > max_length:
+ truncated = True
+ break
+ # truncate
+ if len(text_token) > max_length:
+ truncated = True
+ text_token = text_token[:max_length]
+ text_weight = text_weight[:max_length]
+ tokens.append(text_token)
+ weights.append(text_weight)
+ if truncated:
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+ return tokens, weights
+
+
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
+ r"""
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
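+
+ For example (illustrative, with no_boseos_middle=True): a 5-token prompt becomes
+ [bos] + ids + [pad] * (max_length - 7) + [eos], and its weights become [1.0] + weights + [1.0] * (max_length - 6).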
+ """
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+ for i in range(len(tokens)):
+ tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
+ if no_boseos_middle:
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+ else:
+ w = []
+ if len(weights[i]) == 0:
+ w = [1.0] * weights_length
+ else:
+ for j in range(max_embeddings_multiples):
+ w.append(1.0) # weight for starting token in this chunk
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+ w.append(1.0) # weight for ending token in this chunk
+ w += [1.0] * (weights_length - len(w))
+ weights[i] = w[:]
+
+ return tokens, weights
+
+
+def get_unweighted_text_embeddings(
+ pipe,
+ text_input: np.array,
+ chunk_length: int,
+ no_boseos_middle: Optional[bool] = True,
+):
+ """
+ When the padded token sequence spans multiple chunks of the text encoder capacity,
+ it is split into chunks and each chunk is sent to the text encoder individually.
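+
+ For example (illustrative): with chunk_length=77, a padded input of 152 tokens yields
+ max_embeddings_multiples=2, so two 77-token chunks are encoded separately and concatenated
+ along the sequence axis.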
+ """
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+ if max_embeddings_multiples > 1:
+ text_embeddings = []
+ for i in range(max_embeddings_multiples):
+ # extract the i-th chunk
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()
+
+ # cover the head and the tail with the starting and the ending tokens
+ text_input_chunk[:, 0] = text_input[0, 0]
+ text_input_chunk[:, -1] = text_input[0, -1]
+
+ text_embedding = pipe.text_encoder(input_ids=text_input_chunk)[0]
+
+ if no_boseos_middle:
+ if i == 0:
+ # discard the ending token
+ text_embedding = text_embedding[:, :-1]
+ elif i == max_embeddings_multiples - 1:
+ # discard the starting token
+ text_embedding = text_embedding[:, 1:]
+ else:
+ # discard both starting and ending tokens
+ text_embedding = text_embedding[:, 1:-1]
+
+ text_embeddings.append(text_embedding)
+ text_embeddings = np.concatenate(text_embeddings, axis=1)
+ else:
+ text_embeddings = pipe.text_encoder(input_ids=text_input)[0]
+ return text_embeddings
+
+
+def get_weighted_text_embeddings(
+ pipe,
+ prompt: Union[str, List[str]],
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
+ max_embeddings_multiples: Optional[int] = 4,
+ no_boseos_middle: Optional[bool] = False,
+ skip_parsing: Optional[bool] = False,
+ skip_weighting: Optional[bool] = False,
+ **kwargs,
+):
+ r"""
+ Prompts can be assigned local weights using brackets. For example, the
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+ and the embedding tokens corresponding to those words get multiplied by a constant, 1.1.
+
+ Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.
+
+ Args:
+ pipe (`OnnxStableDiffusionPipeline`):
+ Pipe to provide access to the tokenizer and the text encoder.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ uncond_prompt (`str` or `List[str]`):
+ The unconditional prompt or prompts for guide the image generation. If unconditional prompt
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
+ max_embeddings_multiples (`int`, *optional*, defaults to `4`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
+ When the text spans multiple text-encoder chunks, whether to drop the starting and
+ ending tokens of the intermediate chunks.
+ skip_parsing (`bool`, *optional*, defaults to `False`):
+ Skip the parsing of brackets.
+ skip_weighting (`bool`, *optional*, defaults to `False`):
+ Skip the weighting. When parsing is skipped, weighting is skipped as well.
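+
+ Example (illustrative, not from the original docstring; assumes `pipe` is an ONNX Stable Diffusion
+ pipeline with a CLIP tokenizer):
+ >>> cond, uncond = get_weighted_text_embeddings(pipe, "a (very beautiful:1.3) landscape", uncond_prompt="")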
+ """
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ if not skip_parsing:
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+ else:
+ prompt_tokens = [
+ token[1:-1]
+ for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids
+ ]
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens = [
+ token[1:-1]
+ for token in pipe.tokenizer(
+ uncond_prompt,
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ ).input_ids
+ ]
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
+ max_length = max([len(token) for token in prompt_tokens])
+ if uncond_prompt is not None:
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+ max_embeddings_multiples = min(
+ max_embeddings_multiples,
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
+ )
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+ # pad the length of tokens and weights
+ bos = pipe.tokenizer.bos_token_id
+ eos = pipe.tokenizer.eos_token_id
+ pad = getattr(pipe.tokenizer, "pad_token_id", eos)
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
+ prompt_tokens,
+ prompt_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
+ if uncond_prompt is not None:
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
+ uncond_tokens,
+ uncond_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ uncond_tokens = np.array(uncond_tokens, dtype=np.int32)
+
+ # get the embeddings
+ text_embeddings = get_unweighted_text_embeddings(
+ pipe,
+ prompt_tokens,
+ pipe.tokenizer.model_max_length,
+ no_boseos_middle=no_boseos_middle,
+ )
+ prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
+ if uncond_prompt is not None:
+ uncond_embeddings = get_unweighted_text_embeddings(
+ pipe,
+ uncond_tokens,
+ pipe.tokenizer.model_max_length,
+ no_boseos_middle=no_boseos_middle,
+ )
+ uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
+
+ # assign weights to the prompts and normalize in the sense of mean
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
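+ # In other words (comment added for clarity): each token embedding is multiplied by its parsed weight,
+ # and the result is rescaled so that its overall mean matches the mean it had before weighting.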
+ if (not skip_parsing) and (not skip_weighting):
+ previous_mean = text_embeddings.mean(axis=(-2, -1))
+ text_embeddings *= prompt_weights[:, :, None]
+ text_embeddings *= (previous_mean / text_embeddings.mean(axis=(-2, -1)))[:, None, None]
+ if uncond_prompt is not None:
+ previous_mean = uncond_embeddings.mean(axis=(-2, -1))
+ uncond_embeddings *= uncond_weights[:, :, None]
+ uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=(-2, -1)))[:, None, None]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if uncond_prompt is not None:
+ return text_embeddings, uncond_embeddings
+
+ return text_embeddings
+
+
+def preprocess_image(image):
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask, scale_factor=8):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (this transpose is a no-op kept from the original implementation)
+ mask = 1 - mask # repaint white, keep black
+ return mask
+
+
+class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
+ parsing weights in the prompt.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
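+
+ Example (illustrative sketch, not part of the original docstring; the model id, `custom_pipeline` name and
+ execution provider are placeholders, and the checkpoint is assumed to have been exported to ONNX):
+ >>> from diffusers import DiffusionPipeline
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", custom_pipeline="lpw_stable_diffusion_onnx", provider="CPUExecutionProvider"
+ ... )
+ >>> image = pipe(prompt="a (sunlit:1.2) forest path", num_inference_steps=25).images[0]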
+ """
+
+ if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: SchedulerMixin,
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=requires_safety_checker,
+ )
+ self.__init__additional__()
+
+ else:
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: SchedulerMixin,
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.__init__additional__()
+
+ def __init__additional__(self):
+ self.unet.config.in_channels = 4
+ self.vae_scale_factor = 8
+
+ def _encode_prompt(
+ self,
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ max_embeddings_multiples,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * batch_size
+ if batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+ pipe=self,
+ prompt=prompt,
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+ max_embeddings_multiples=max_embeddings_multiples,
+ )
+
+ text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
+ if do_classifier_free_guidance:
+ uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
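+
+ # Shape note (descriptive): with classifier-free guidance the returned array stacks the
+ # unconditional embeddings before the prompt embeddings, i.e. it has shape
+ # (2 * batch_size * num_images_per_prompt, sequence_length, hidden_dim).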
+
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, is_text2img):
+ if is_text2img:
+ return self.scheduler.timesteps, num_inference_steps
+ else:
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+ return timesteps, num_inference_steps - t_start
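+
+ # Worked example (informal): with num_inference_steps=50, strength=0.8 and
+ # steps_offset=0, init_timestep = int(50 * 0.8) = 40 and t_start = 10, so
+ # img2img/inpaint denoise over the last 40 of the 50 scheduler timesteps.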
+
+ def run_safety_checker(self, image):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+ # Calling the safety_checker directly with batch size > 1 raises an error, so run it image by image
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # the half-precision VAE decoder can give wrong results when batch size > 1, so decode one latent at a time
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+ return image
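+
+ # Output note (descriptive): the decoded images are float arrays in [0, 1] with
+ # NHWC layout, i.e. shape (batch_size, height, width, 3).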
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
+ if image is None:
+ shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+
+ if latents is None:
+ latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
+ return latents, None, None
+ else:
+ init_latents = self.vae_encoder(sample=image)[0]
+ init_latents = 0.18215 * init_latents
+ init_latents = np.concatenate([init_latents] * batch_size, axis=0)
+ init_latents_orig = init_latents
+ shape = init_latents.shape
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
+ latents = self.scheduler.add_noise(
+ torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
+ ).numpy()
+ return latents, init_latents_orig, noise
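+
+ # Branch summary (descriptive): for text2img (image is None) this returns freshly
+ # sampled, sigma-scaled latents and (None, None); for img2img/inpaint it returns the
+ # noised VAE latents together with the original latents and the noise, which the
+ # masking step in the denoising loop reuses.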
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ image: Union[np.ndarray, PIL.Image.Image] = None,
+ mask_image: Union[np.ndarray, PIL.Image.Image] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ strength: float = 0.8,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[np.ndarray] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`np.ndarray`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
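+
+ Example:
+ A minimal usage sketch; `pipe` is assumed to be an already-loaded instance of
+ this pipeline, and the prompt below is only illustrative.
+
+ ```py
+ image = pipe(
+ prompt="a (red:1.3) cat sitting on a bench, " * 20,
+ negative_prompt="blurry, low quality",
+ num_inference_steps=25,
+ ).images[0]
+ ```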
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, strength, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ max_embeddings_multiples,
+ )
+ dtype = text_embeddings.dtype
+
+ # 4. Preprocess image and mask
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess_image(image)
+ if image is not None:
+ image = image.astype(dtype)
+ if isinstance(mask_image, PIL.Image.Image):
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+ if mask_image is not None:
+ mask = mask_image.astype(dtype)
+ mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
+ else:
+ mask = None
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents, init_latents_orig, noise = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+ latent_model_input = latent_model_input.numpy()
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=np.array([t], dtype=timestep_dtype),
+ encoder_hidden_states=text_embeddings,
+ )
+ noise_pred = noise_pred[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ )
+ latents = scheduler_output.prev_sample.numpy()
+
+ if mask is not None:
+ # masking
+ init_latents_proper = self.scheduler.add_noise(
+ torch.from_numpy(init_latents_orig),
+ torch.from_numpy(noise),
+ t,
+ ).numpy()
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # call the callback, if provided
+ if i % callback_steps == 0:
+ if callback is not None:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+ if is_cancelled_callback is not None and is_cancelled_callback():
+ return None
+
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image)
+
+ # 11. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return image, has_nsfw_concept
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def text2img(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[np.ndarray] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function for text-to-image generation.
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`np.ndarray`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ def img2img(
+ self,
+ image: Union[np.ndarray, PIL.Image.Image],
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function for image-to-image generation.
+ Args:
+ image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or ndarray representing an image batch, that will be used as the starting point for the
+ process.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ strength=strength,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ def inpaint(
+ self,
+ image: Union[np.ndarray, PIL.Image.Image],
+ mask_image: Union[np.ndarray, PIL.Image.Image],
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function for inpainting.
+ Args:
+ image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
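+
+ Example:
+ A minimal usage sketch; `pipe`, `init_image` and `mask` are assumed to be an
+ already-loaded instance of this pipeline and two PIL images of the same size.
+
+ ```py
+ image = pipe.inpaint(
+ image=init_image,
+ mask_image=mask,
+ prompt="a (glass:1.2) vase with flowers",
+ strength=0.75,
+ ).images[0]
+ ```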
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ strength=strength,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
diff --git a/diffusers/examples/community/lpw_stable_diffusion_xl.py b/diffusers/examples/community/lpw_stable_diffusion_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..13d1e2a1156aed7a877cdaf77e00e97d97f9454e
--- /dev/null
+++ b/diffusers/examples/community/lpw_stable_diffusion_xl.py
@@ -0,0 +1,2243 @@
+## ----------------------------------------------------------
+# An SDXL pipeline that can take an unlimited-length weighted prompt
+#
+# Author: Andrew Zhu
+# GitHub: https://github.com/xhinker
+# Medium: https://medium.com/@xhinker
+## -----------------------------------------------------------
+
+import inspect
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_invisible_watermark_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1 (i.e. multiplies it by 1/1.1)
+ \\( - literal character '('
+ \\[ - literal character '['
+ \\) - literal character ')'
+ \\] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\\(literal\\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+ import re
+
+ re_attention = re.compile(
+ r"""
+ \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
+ \)|]|[^\\()\[\]:]+|:
+ """,
+ re.X,
+ )
+
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
+
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith("\\"):
+ res.append([text[1:], 1.0])
+ elif text == "(":
+ round_brackets.append(len(res))
+ elif text == "[":
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ")" and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == "]" and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ parts = re.split(re_break, text)
+ for i, part in enumerate(parts):
+ if i > 0:
+ res.append(["BREAK", -1])
+ res.append([part, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
+ return res
+
+
+def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
+ """
+ Get prompt token ids and weights. This function works for both the prompt and the negative prompt.
+
+ Args:
+ clip_tokenizer (CLIPTokenizer)
+ A CLIPTokenizer
+ prompt (str)
+ A prompt string with weights
+
+ Returns:
+ text_tokens (list)
+ A list containing the token ids
+ text_weights (list)
+ A list containing the corresponding weight of each token id
+
+ Example:
+ import torch
+ from transformers import CLIPTokenizer
+
+ clip_tokenizer = CLIPTokenizer.from_pretrained(
+ "stablediffusionapi/deliberate-v2"
+ , subfolder = "tokenizer"
+ , dtype = torch.float16
+ )
+
+ token_id_list, token_weight_list = get_prompts_tokens_with_weights(
+ clip_tokenizer = clip_tokenizer
+ ,prompt = "a (red:1.5) cat"*70
+ )
+ """
+ texts_and_weights = parse_prompt_attention(prompt)
+ text_tokens, text_weights = [], []
+ for word, weight in texts_and_weights:
+ # tokenize and discard the starting and the ending token
+ token = clip_tokenizer(word, truncation=False).input_ids[1:-1]  # no truncation, so a prompt of any length can be tokenized
+ # the returned token is a 1d list: [320, 1125, 539, 320]
+
+ # merge the new tokens to the all tokens holder: text_tokens
+ text_tokens = [*text_tokens, *token]
+
+ # each token chunk will come with one weight, like ['red cat', 2.0]
+ # need to expand weight for each token.
+ chunk_weights = [weight] * len(token)
+
+ # append the weight back to the weight holder: text_weights
+ text_weights = [*text_weights, *chunk_weights]
+ return text_tokens, text_weights
+
+
+def group_tokens_and_weights(token_ids: list, weights: list, pad_last_block=False):
+ """
+ Produce tokens and weights in groups and pad the missing tokens
+
+ Args:
+ token_ids (list)
+ The token ids from tokenizer
+ weights (list)
+ The weights list from function get_prompts_tokens_with_weights
+ pad_last_block (bool)
+ Controls whether to pad the last token chunk to 75 tokens with eos
+ Returns:
+ new_token_ids (2d list)
+ new_weights (2d list)
+
+ Example:
+ token_groups,weight_groups = group_tokens_and_weights(
+ token_ids = token_id_list
+ , weights = token_weight_list
+ )
+ """
+ bos, eos = 49406, 49407
+
+ # this will be a 2d list
+ new_token_ids = []
+ new_weights = []
+ while len(token_ids) >= 75:
+ # get the first 75 tokens
+ head_75_tokens = [token_ids.pop(0) for _ in range(75)]
+ head_75_weights = [weights.pop(0) for _ in range(75)]
+
+ # extract token ids and weights
+ temp_77_token_ids = [bos] + head_75_tokens + [eos]
+ temp_77_weights = [1.0] + head_75_weights + [1.0]
+
+ # add 77 token and weights chunk to the holder list
+ new_token_ids.append(temp_77_token_ids)
+ new_weights.append(temp_77_weights)
+
+ # pad the remaining tokens (fewer than 75) into a final chunk
+ if len(token_ids) > 0:
+ padding_len = 75 - len(token_ids) if pad_last_block else 0
+
+ temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
+ new_token_ids.append(temp_77_token_ids)
+
+ temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
+ new_weights.append(temp_77_weights)
+
+ return new_token_ids, new_weights
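+
+ # Worked example (informal): a 160-token prompt is split into chunks of
+ # 75 + bos/eos = 77, 77 and 12 ids (or 77, 77, 77 with pad_last_block=True);
+ # the weights are grouped the same way, with 1.0 at the bos/eos/pad positions.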
+
+
+def get_weighted_text_embeddings_sdxl(
+ pipe: StableDiffusionXLPipeline,
+ prompt: str = "",
+ prompt_2: str = None,
+ neg_prompt: str = "",
+ neg_prompt_2: str = None,
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ clip_skip: Optional[int] = None,
+ lora_scale: Optional[int] = None,
+):
+ """
+ This function can process a long prompt with weights for Stable Diffusion XL,
+ with no length limitation.
+
+ Args:
+ pipe (StableDiffusionPipeline)
+ prompt (str)
+ prompt_2 (str)
+ neg_prompt (str)
+ neg_prompt_2 (str)
+ num_images_per_prompt (int)
+ device (torch.device)
+ clip_skip (int)
+ Returns:
+ prompt_embeds (torch.Tensor)
+ neg_prompt_embeds (torch.Tensor)
+ pooled_prompt_embeds (torch.Tensor)
+ negative_pooled_prompt_embeds (torch.Tensor)
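+
+ Example:
+ A minimal usage sketch; `pipe` is assumed to be a loaded `StableDiffusionXLPipeline`.
+
+ ```py
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = get_weighted_text_embeddings_sdxl(pipe, prompt="a (red:1.5) cat " * 40, neg_prompt="blurry")
+ image = pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ ).images[0]
+ ```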
+ """
+ device = device or pipe._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
+ pipe._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if pipe.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(pipe.text_encoder, lora_scale)
+
+ if pipe.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(pipe.text_encoder_2, lora_scale)
+
+ if prompt_2:
+ prompt = f"{prompt} {prompt_2}"
+
+ if neg_prompt_2:
+ neg_prompt = f"{neg_prompt} {neg_prompt_2}"
+
+ prompt_t1 = prompt_t2 = prompt
+ neg_prompt_t1 = neg_prompt_t2 = neg_prompt
+
+ if isinstance(pipe, TextualInversionLoaderMixin):
+ prompt_t1 = pipe.maybe_convert_prompt(prompt_t1, pipe.tokenizer)
+ neg_prompt_t1 = pipe.maybe_convert_prompt(neg_prompt_t1, pipe.tokenizer)
+ prompt_t2 = pipe.maybe_convert_prompt(prompt_t2, pipe.tokenizer_2)
+ neg_prompt_t2 = pipe.maybe_convert_prompt(neg_prompt_t2, pipe.tokenizer_2)
+
+ eos = pipe.tokenizer.eos_token_id
+
+ # tokenizer 1
+ prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt_t1)
+ neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt_t1)
+
+ # tokenizer 2
+ prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt_t2)
+ neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt_t2)
+
+ # padding the shorter one for prompt set 1
+ prompt_token_len = len(prompt_tokens)
+ neg_prompt_token_len = len(neg_prompt_tokens)
+
+ if prompt_token_len > neg_prompt_token_len:
+ # padding the neg_prompt with eos token
+ neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
+ neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
+ else:
+ # padding the prompt
+ prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
+ prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
+
+ # padding the shorter one for token set 2
+ prompt_token_len_2 = len(prompt_tokens_2)
+ neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
+
+ if prompt_token_len_2 > neg_prompt_token_len_2:
+ # padding the neg_prompt with eos token
+ neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+ neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+ else:
+ # padding the prompt
+ prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+ prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+
+ embeds = []
+ neg_embeds = []
+
+ prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
+
+ neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
+ neg_prompt_tokens.copy(), neg_prompt_weights.copy()
+ )
+
+ prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
+ prompt_tokens_2.copy(), prompt_weights_2.copy()
+ )
+
+ neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
+ neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
+ )
+
+ # build the embeddings chunk by chunk: process each 77-token group separately and concatenate at the end.
+ for i in range(len(prompt_token_groups)):
+ # get positive prompt embeddings with weights
+ token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
+ weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
+
+ token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
+
+ # use first text encoder
+ prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
+
+ # use second text encoder
+ prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
+ pooled_prompt_embeds = prompt_embeds_2[0]
+
+ if clip_skip is None:
+ prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
+ prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-(clip_skip + 2)]
+ prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
+ token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
+
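+ # Apply the per-token weights: each weighted token embedding is linearly interpolated
+ # between the chunk's final (eos) embedding and its own embedding, so a weight of 1.0
+ # leaves the token unchanged and larger weights move it further away from the eos embedding.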
+ for j in range(len(weight_tensor)):
+ if weight_tensor[j] != 1.0:
+ token_embedding[j] = (
+ token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
+ )
+
+ token_embedding = token_embedding.unsqueeze(0)
+ embeds.append(token_embedding)
+
+ # get negative prompt embeddings with weights
+ neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
+ neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
+ neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
+
+ # use first text encoder
+ neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
+ neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
+
+ # use second text encoder
+ neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
+ neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
+ negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
+
+ neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
+ neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
+
+ for z in range(len(neg_weight_tensor)):
+ if neg_weight_tensor[z] != 1.0:
+ neg_token_embedding[z] = (
+ neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
+ )
+
+ neg_token_embedding = neg_token_embedding.unsqueeze(0)
+ neg_embeds.append(neg_token_embedding)
+
+ prompt_embeds = torch.cat(embeds, dim=1)
+ negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if pipe.text_encoder is not None:
+ if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(pipe.text_encoder, lora_scale)
+
+ if pipe.text_encoder_2 is not None:
+ if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(pipe.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+
+# -------------------------------------------------------------------------------------------------------------------------------
+# reuse the backbone code from StableDiffusionXLPipeline
+# -------------------------------------------------------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0"
+ , torch_dtype = torch.float16
+ , use_safetensors = True
+ , variant = "fp16"
+ , custom_pipeline = "lpw_stable_diffusion_xl",
+ )
+
+ prompt = "a white cat running on the grass"*20
+ prompt2 = "play a football"*20
+ prompt = f"{prompt},{prompt2}"
+ neg_prompt = "blur, low quality"
+
+ pipe.to("cuda")
+ images = pipe(
+ prompt = prompt
+ , negative_prompt = neg_prompt
+ ).images[0]
+
+ pipe.to("cpu")
+ torch.cuda.empty_cache()
+ images
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
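+
+ # Note (descriptive): guidance_rescale=0.0 returns noise_cfg unchanged, while 1.0 fully
+ # rescales it to match the per-sample std of the text-conditioned prediction.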
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class SDXLLongPromptWeightingPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "negative_add_time_ids",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ feature_extractor: Optional[CLIPImageProcessor] = None,
+ image_encoder: Optional[CLIPVisionModelWithProjection] = None,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+ self.default_sample_size = self.unet.config.sample_size
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ model_sequence = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+ model_sequence.extend([self.unet, self.vae])
+
+ hook = None
+ for cpu_offloaded_model in model_sequence:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are always interested only in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
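+ # The penultimate hidden states of both text encoders are concatenated along the
+ # feature dimension; for the standard SDXL encoders (CLIP ViT-L, 768 + OpenCLIP
+ # ViT-bigG, 1280) this yields 2048-dimensional token embeddings.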
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always interested only in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
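+ # Worked example (denoising_start=None, first-order scheduler): strength=0.8 with
+ # num_inference_steps=50 gives init_timestep=40 and t_start=10, so only the last
+ # 40 of the 50 scheduled timesteps are run.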
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
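+ # e.g. with the default num_train_timesteps=1000 and denoising_start=0.8, the
+ # cutoff is round(1000 - 0.8 * 1000) = 200, so only timesteps strictly below 200
+ # (roughly the final 20% of the noise schedule) are kept.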
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ timesteps = timesteps[-num_inference_steps:]
+ return timesteps, num_inference_steps
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(
+ self,
+ image,
+ mask,
+ width,
+ height,
+ num_channels_latents,
+ timestep,
+ batch_size,
+ num_images_per_prompt,
+ dtype,
+ device,
+ generator=None,
+ add_noise=True,
+ latents=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ batch_size *= num_images_per_prompt
+
+ if image is None:
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ elif mask is None:
+ if not isinstance(image, (torch.Tensor, Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.text_encoder_2.to("cpu")
+ torch.cuda.empty_cache()
+
+ image = image.to(device=device, dtype=dtype)
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ if add_noise:
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+ return latents
+
+ else:
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1.0 then initialise the latents to pure noise, else initialise them to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
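+ # For the standard SDXL base checkpoint this works out to 256 * 6 + 1280 = 2816,
+ # which must match the input width of the UNet's additional-embedding projection.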
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (AttnProcessor2_0, XFormersAttnProcessor),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Guidance scale values at which to generate the embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
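+ # When the UNet defines `time_cond_proj_dim` (e.g. LCM-style guidance-distilled
+ # models), classifier-free guidance is skipped and the guidance scale is instead
+ # injected through the embedding prepared in step 8 of `__call__`.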
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def denoising_start(self):
+ return self._denoising_start
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: str = None,
+ prompt_2: Optional[str] = None,
+ image: Optional[PipelineImageInput] = None,
+ mask_image: Optional[PipelineImageInput] = None,
+ masked_image_latents: Optional[torch.Tensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str`):
+ The prompt to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
+ prompt_2 (`str`):
+ The prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PipelineImageInput`, *optional*):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ mask_image (`PipelineImageInput`, *optional*):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
+ Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image
+ Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str`):
+ The prompt not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str`):
+ The prompt not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
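+ # For the standard SDXL base UNet (sample_size=128) with vae_scale_factor=8 this
+ # defaults to 1024 x 1024.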
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ if ip_adapter_image is not None:
+ output_hidden_state = not isinstance(self.unet.encoder_hid_proj, ImageProjection)
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
+ )
+
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = get_weighted_text_embeddings_sdxl(
+ pipe=self,
+ prompt=prompt,
+ neg_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=clip_skip,
+ lora_scale=lora_scale,
+ )
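+ # `get_weighted_text_embeddings_sdxl` comes from the long-prompt-weighting SDXL
+ # community pipeline; it supports attention-weighting syntax such as "(word:1.2)"
+ # and chunks prompts longer than the 77-token CLIP limit. Exact behaviour depends
+ # on that helper's implementation.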
+ dtype = prompt_embeds.dtype
+
+ if isinstance(image, Image.Image):
+ image = self.image_processor.preprocess(image, height=height, width=width)
+ if image is not None:
+ image = image.to(device=self.device, dtype=dtype)
+
+ if isinstance(mask_image, Image.Image):
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+ else:
+ mask = mask_image
+ if mask_image is not None:
+ mask = mask.to(device=self.device, dtype=dtype)
+
+ if masked_image_latents is not None:
+ masked_image = masked_image_latents
+ elif image.shape[1] == 4:
+ # if image is in latent space, we can't mask it
+ masked_image = None
+ else:
+ masked_image = image * (mask < 0.5)
+ else:
+ mask = None
+
+ # 4. Prepare timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ if image is not None:
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ is_strength_max = strength == 1.0
+ add_noise = self.denoising_start is None
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
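+ # A 4-channel UNet is a regular (non-inpainting) SDXL UNet: it needs the clean
+ # image latents so the unmasked region can be blended back in at every step.
+ # A 9-channel UNet instead receives the mask and masked-image latents as extra
+ # input channels (see the concatenation in the denoising loop below).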
+
+ latents = self.prepare_latents(
+ image=image,
+ mask=mask,
+ width=width,
+ height=height,
+ num_channels_latents=num_channels_unet,
+ timestep=latent_timestep,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ add_noise=add_noise,
+ latents=latents,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if mask is not None:
+ if return_image_latents:
+ latents, noise, image_latents = latents
+ else:
+ latents, noise = latents
+
+ # 5.1 Prepare mask latent variables
+ if mask is not None:
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask=mask,
+ masked_image=masked_image,
+ batch_size=batch_size * num_images_per_prompt,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ )
+
+ # Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else {}
+
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 7.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and self.denoising_start is not None
+ and denoising_value_valid(self.denoising_end)
+ and denoising_value_valid(self.denoising_start)
+ and self.denoising_start >= self.denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {self.denoising_end} when using type float."
+ )
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
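+ # e.g. with num_train_timesteps=1000 and denoising_end=0.8 the cutoff is 200, so
+ # only timesteps >= 200 (roughly the first 80% of the schedule) are run here and
+ # the remainder is left for a refiner with denoising_start=0.8.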
+
+ # 8. Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+
+ # 9. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if mask is not None and num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ added_cond_kwargs.update({"text_embeds": add_text_embeds, "time_ids": add_time_ids})
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
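+ # e.g. with the default guidance_scale=5.0 this is uncond + 5 * (text - uncond),
+ # pushing the prediction further in the direction of the text-conditioned output.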
+
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if mask is not None and num_channels_unet == 4:
+ init_latents_proper = image_latents
+
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
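+ # Unmasked regions (init_mask ~ 0) are replaced with the re-noised original
+ # image latents, while masked regions (init_mask ~ 1) keep the freshly denoised
+ # latents, so only the masked area is actually repainted.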
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ def text2img(
+ self,
+ prompt: str = None,
+ prompt_2: Optional[str] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling pipeline for text-to-image.
+
+ Refer to the documentation of the `__call__` method for parameter descriptions.
+ """
+ return self.__call__(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ denoising_start=denoising_start,
+ denoising_end=denoising_end,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ ip_adapter_image=ip_adapter_image,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ output_type=output_type,
+ return_dict=return_dict,
+ cross_attention_kwargs=cross_attention_kwargs,
+ guidance_rescale=guidance_rescale,
+ original_size=original_size,
+ crops_coords_top_left=crops_coords_top_left,
+ target_size=target_size,
+ clip_skip=clip_skip,
+ callback_on_step_end=callback_on_step_end,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ **kwargs,
+ )
+
+ def img2img(
+ self,
+ prompt: str = None,
+ prompt_2: Optional[str] = None,
+ image: Optional[PipelineImageInput] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling pipeline for image-to-image.
+
+ Refer to the documentation of the `__call__` method for parameter descriptions.
+ """
+ return self.__call__(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ image=image,
+ height=height,
+ width=width,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ denoising_start=denoising_start,
+ denoising_end=denoising_end,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ ip_adapter_image=ip_adapter_image,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ output_type=output_type,
+ return_dict=return_dict,
+ cross_attention_kwargs=cross_attention_kwargs,
+ guidance_rescale=guidance_rescale,
+ original_size=original_size,
+ crops_coords_top_left=crops_coords_top_left,
+ target_size=target_size,
+ clip_skip=clip_skip,
+ callback_on_step_end=callback_on_step_end,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ **kwargs,
+ )
+
+ def inpaint(
+ self,
+ prompt: str = None,
+ prompt_2: Optional[str] = None,
+ image: Optional[PipelineImageInput] = None,
+ mask_image: Optional[PipelineImageInput] = None,
+ masked_image_latents: Optional[torch.Tensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for inpainting.
+
+ Refer to the documentation of the `__call__` method for parameter descriptions.
+ """
+ return self.__call__(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ image=image,
+ mask_image=mask_image,
+ masked_image_latents=masked_image_latents,
+ height=height,
+ width=width,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ denoising_start=denoising_start,
+ denoising_end=denoising_end,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ ip_adapter_image=ip_adapter_image,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ output_type=output_type,
+ return_dict=return_dict,
+ cross_attention_kwargs=cross_attention_kwargs,
+ guidance_rescale=guidance_rescale,
+ original_size=original_size,
+ crops_coords_top_left=crops_coords_top_left,
+ target_size=target_size,
+ clip_skip=clip_skip,
+ callback_on_step_end=callback_on_step_end,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ **kwargs,
+ )
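+
+ # Example (a minimal sketch; `pipe` is assumed to be an already loaded instance of
+ # this pipeline, and `init_image` / `mask_image` are placeholder PIL images):
+ #
+ # result = pipe.img2img(prompt="a photo of a cat", image=init_image, strength=0.6)
+ # result = pipe.inpaint(prompt="a photo of a cat", image=init_image, mask_image=mask_image)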
+
+ # Override to properly handle the loading and unloading of the additional text encoder.
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = False,
+ ):
+ state_dict = {}
+
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def _remove_text_encoder_monkey_patch(self):
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
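+
+ # Example (a minimal sketch; `pipe` and the LoRA layer dictionaries are assumed to
+ # already exist):
+ #
+ # pipe.save_lora_weights(
+ # "my-sdxl-lora",
+ # unet_lora_layers=unet_lora_layers,
+ # text_encoder_lora_layers=text_encoder_lora_layers,
+ # text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ # safe_serialization=True,
+ # )
+ # pipe.load_lora_weights("my-sdxl-lora")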
diff --git a/diffusers/examples/community/magic_mix.py b/diffusers/examples/community/magic_mix.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3d118f84bfcdcdabad76a074515efec42f20bb7
--- /dev/null
+++ b/diffusers/examples/community/magic_mix.py
@@ -0,0 +1,152 @@
+from typing import Union
+
+import torch
+from PIL import Image
+from torchvision import transforms as tfms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+
+
+class MagicMixPipeline(DiffusionPipeline):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
+ ):
+ super().__init__()
+
+ self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+ # convert PIL image to latents
+ def encode(self, img):
+ with torch.no_grad():
+ latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)
+ latent = 0.18215 * latent.latent_dist.sample()
+ return latent
+
+ # convert latents to PIL image
+ def decode(self, latent):
+ latent = (1 / 0.18215) * latent
+ with torch.no_grad():
+ img = self.vae.decode(latent).sample
+ img = (img / 2 + 0.5).clamp(0, 1)
+ img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
+ img = (img * 255).round().astype("uint8")
+ return Image.fromarray(img[0])
+
+ # convert the prompt into text embeddings, along with unconditional (empty-prompt) embeddings
+ def prep_text(self, prompt):
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+ uncond_input = self.tokenizer(
+ "",
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ return torch.cat([uncond_embedding, text_embedding])
+
+ def __call__(
+ self,
+ img: Image.Image,
+ prompt: str,
+ kmin: float = 0.3,
+ kmax: float = 0.6,
+ mix_factor: float = 0.5,
+ seed: int = 42,
+ steps: int = 50,
+ guidance_scale: float = 7.5,
+ ) -> Image.Image:
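+ # kmin/kmax are fractions of the schedule: steps with index in (tmax, tmin) form the
+ # layout generation phase, the remaining later steps form the content generation phase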
+ tmin = steps - int(kmin * steps)
+ tmax = steps - int(kmax * steps)
+
+ text_embeddings = self.prep_text(prompt)
+
+ self.scheduler.set_timesteps(steps)
+
+ width, height = img.size
+ encoded = self.encode(img)
+
+ torch.manual_seed(seed)
+ noise = torch.randn(
+ (1, self.unet.config.in_channels, height // 8, width // 8),
+ ).to(self.device)
+
+ latents = self.scheduler.add_noise(
+ encoded,
+ noise,
+ timesteps=self.scheduler.timesteps[tmax],
+ )
+
+ input = torch.cat([latents] * 2)
+
+ input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])
+
+ with torch.no_grad():
+ pred = self.unet(
+ input,
+ self.scheduler.timesteps[tmax],
+ encoder_hidden_states=text_embeddings,
+ ).sample
+
+ pred_uncond, pred_text = pred.chunk(2)
+ pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+
+ latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample
+
+ for i, t in enumerate(tqdm(self.scheduler.timesteps)):
+ if i > tmax:
+ if i < tmin: # layout generation phase
+ orig_latents = self.scheduler.add_noise(
+ encoded,
+ noise,
+ timesteps=t,
+ )
+
+ input = (
+ (mix_factor * latents) + (1 - mix_factor) * orig_latents
+ ) # interpolating between layout noise and conditionally generated noise to preserve layout semantics
+ input = torch.cat([input] * 2)
+
+ else: # content generation phase
+ input = torch.cat([latents] * 2)
+
+ input = self.scheduler.scale_model_input(input, t)
+
+ with torch.no_grad():
+ pred = self.unet(
+ input,
+ t,
+ encoder_hidden_states=text_embeddings,
+ ).sample
+
+ pred_uncond, pred_text = pred.chunk(2)
+ pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+
+ latents = self.scheduler.step(pred, t, latents).prev_sample
+
+ return self.decode(latents)
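+
+# Example usage (a minimal sketch; the checkpoint name and image path are placeholders):
+#
+#   from diffusers import DiffusionPipeline
+#   from PIL import Image
+#
+#   pipe = DiffusionPipeline.from_pretrained(
+#       "CompVis/stable-diffusion-v1-4",
+#       custom_pipeline="magic_mix",
+#   ).to("cuda")
+#   img = Image.open("input.jpg").convert("RGB").resize((512, 512))
+#   mixed = pipe(img, prompt="bed", kmin=0.3, kmax=0.6, mix_factor=0.5)
+#   mixed.save("magic_mix.png")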
diff --git a/diffusers/examples/community/marigold_depth_estimation.py b/diffusers/examples/community/marigold_depth_estimation.py
new file mode 100644
index 0000000000000000000000000000000000000000..92f01d046ef927edc497d284ee9b8248c8c2c705
--- /dev/null
+++ b/diffusers/examples/community/marigold_depth_estimation.py
@@ -0,0 +1,674 @@
+# Copyright 2024 Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# If you find this code useful, we kindly ask you to cite our paper in your work.
+# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+# More information about the method can be found at https://marigoldmonodepth.github.io
+# --------------------------------------------------------------------------
+
+
+import logging
+import math
+from typing import Dict, Union
+
+import matplotlib
+import numpy as np
+import torch
+from PIL import Image
+from PIL.Image import Resampling
+from scipy.optimize import minimize
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LCMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.utils import BaseOutput, check_min_version
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+
+class MarigoldDepthOutput(BaseOutput):
+ """
+ Output class for Marigold monocular depth prediction pipeline.
+
+ Args:
+ depth_np (`np.ndarray`):
+ Predicted depth map, with depth values in the range of [0, 1].
+ depth_colored (`None` or `PIL.Image.Image`):
+ Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
+ uncertainty (`None` or `np.ndarray`):
+ Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
+ """
+
+ depth_np: np.ndarray
+ depth_colored: Union[None, Image.Image]
+ uncertainty: Union[None, np.ndarray]
+
+
+def get_pil_resample_method(method_str: str) -> Resampling:
+ resample_method_dic = {
+ "bilinear": Resampling.BILINEAR,
+ "bicubic": Resampling.BICUBIC,
+ "nearest": Resampling.NEAREST,
+ }
+ resample_method = resample_method_dic.get(method_str, None)
+ if resample_method is None:
+ raise ValueError(f"Unknown resampling method: {resample_method}")
+ else:
+ return resample_method
+
+
+class MarigoldPipeline(DiffusionPipeline):
+ """
+ Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ unet (`UNet2DConditionModel`):
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
+ vae (`AutoencoderKL`):
+ Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
+ to and from latent representations.
+ scheduler (`DDIMScheduler`):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ text_encoder (`CLIPTextModel`):
+ Text-encoder, for empty text embedding.
+ tokenizer (`CLIPTokenizer`):
+ CLIP tokenizer.
+ """
+
+ rgb_latent_scale_factor = 0.18215
+ depth_latent_scale_factor = 0.18215
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ vae: AutoencoderKL,
+ scheduler: DDIMScheduler,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ unet=unet,
+ vae=vae,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ self.empty_text_embed = None
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ input_image: Image,
+ denoising_steps: int = 10,
+ ensemble_size: int = 10,
+ processing_res: int = 768,
+ match_input_res: bool = True,
+ resample_method: str = "bilinear",
+ batch_size: int = 0,
+ seed: Union[int, None] = None,
+ color_map: str = "Spectral",
+ show_progress_bar: bool = True,
+ ensemble_kwargs: Dict = None,
+ ) -> MarigoldDepthOutput:
+ """
+ Function invoked when calling the pipeline.
+
+ Args:
+ input_image (`Image`):
+ Input RGB (or gray-scale) image.
+ processing_res (`int`, *optional*, defaults to `768`):
+ Maximum edge resolution at which the image is processed.
+ If set to 0, the input image is not resized.
+ match_input_res (`bool`, *optional*, defaults to `True`):
+ Resize depth prediction to match input resolution.
+ Only valid if `processing_res` > 0.
+ resample_method (`str`, *optional*, defaults to `bilinear`):
+ Resampling method used to resize images and depth predictions. Can be one of `bilinear`, `bicubic`, or `nearest`.
+ denoising_steps (`int`, *optional*, defaults to `10`):
+ Number of diffusion denoising steps (DDIM) during inference.
+ ensemble_size (`int`, *optional*, defaults to `10`):
+ Number of predictions to be ensembled.
+ batch_size (`int`, *optional*, defaults to `0`):
+ Inference batch size, no bigger than `ensemble_size`.
+ If set to 0, the script will automatically decide the proper batch size.
+ seed (`int`, *optional*, defaults to `None`):
+ Reproducibility seed.
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
+ Display a progress bar of diffusion denoising.
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
+ Colormap used to colorize the depth map.
+ ensemble_kwargs (`dict`, *optional*, defaults to `None`):
+ Arguments for detailed ensembling settings.
+ Returns:
+ `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
+ - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
+ - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
+ - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
+ coming from ensembling. None if `ensemble_size = 1`
+ """
+
+ device = self.device
+ input_size = input_image.size
+
+ if not match_input_res:
+ assert processing_res is not None, "Value error: `processing_res` must be set when `match_input_res` is False."
+ assert processing_res >= 0
+ assert ensemble_size >= 1
+
+ # Check if denoising step is reasonable
+ self._check_inference_step(denoising_steps)
+
+ resample_method: Resampling = get_pil_resample_method(resample_method)
+
+ # ----------------- Image Preprocess -----------------
+ # Resize image
+ if processing_res > 0:
+ input_image = self.resize_max_res(
+ input_image,
+ max_edge_resolution=processing_res,
+ resample_method=resample_method,
+ )
+ # Convert the image to RGB to (1) remove any alpha channel and (2) expand grayscale to 3 channels
+ input_image = input_image.convert("RGB")
+ image = np.asarray(input_image)
+
+ # Normalize rgb values
+ rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
+ rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
+ rgb_norm = rgb_norm.to(device)
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
+
+ # ----------------- Predicting depth -----------------
+ # Batch repeated input image
+ duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
+ if batch_size > 0:
+ _bs = batch_size
+ else:
+ _bs = self._find_batch_size(
+ ensemble_size=ensemble_size,
+ input_res=max(rgb_norm.shape[1:]),
+ dtype=self.dtype,
+ )
+
+ single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)
+
+ # Predict depth maps (batched)
+ depth_pred_ls = []
+ if show_progress_bar:
+ iterable = tqdm(single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False)
+ else:
+ iterable = single_rgb_loader
+ for batch in iterable:
+ (batched_img,) = batch
+ depth_pred_raw = self.single_infer(
+ rgb_in=batched_img,
+ num_inference_steps=denoising_steps,
+ show_pbar=show_progress_bar,
+ seed=seed,
+ )
+ depth_pred_ls.append(depth_pred_raw.detach())
+ depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
+ torch.cuda.empty_cache() # clear vram cache for ensembling
+
+ # ----------------- Test-time ensembling -----------------
+ if ensemble_size > 1:
+ depth_pred, pred_uncert = self.ensemble_depths(depth_preds, **(ensemble_kwargs or {}))
+ else:
+ depth_pred = depth_preds
+ pred_uncert = None
+
+ # ----------------- Post processing -----------------
+ # Scale prediction to [0, 1]
+ min_d = torch.min(depth_pred)
+ max_d = torch.max(depth_pred)
+ depth_pred = (depth_pred - min_d) / (max_d - min_d)
+
+ # Convert to numpy
+ depth_pred = depth_pred.cpu().numpy().astype(np.float32)
+
+ # Resize back to original resolution
+ if match_input_res:
+ pred_img = Image.fromarray(depth_pred)
+ pred_img = pred_img.resize(input_size, resample=resample_method)
+ depth_pred = np.asarray(pred_img)
+
+ # Clip output range
+ depth_pred = depth_pred.clip(0, 1)
+
+ # Colorize
+ if color_map is not None:
+ depth_colored = self.colorize_depth_maps(
+ depth_pred, 0, 1, cmap=color_map
+ ).squeeze() # [3, H, W], value in (0, 1)
+ depth_colored = (depth_colored * 255).astype(np.uint8)
+ depth_colored_hwc = self.chw2hwc(depth_colored)
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
+ else:
+ depth_colored_img = None
+
+ return MarigoldDepthOutput(
+ depth_np=depth_pred,
+ depth_colored=depth_colored_img,
+ uncertainty=pred_uncert,
+ )
+
+ def _check_inference_step(self, n_step: int):
+ """
+ Check whether the number of denoising steps is reasonable for the configured scheduler.
+ Args:
+ n_step (`int`): denoising steps
+ """
+ assert n_step >= 1
+
+ if isinstance(self.scheduler, DDIMScheduler):
+ if n_step < 10:
+ logging.warning(
+ f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
+ )
+ elif isinstance(self.scheduler, LCMScheduler):
+ if not 1 <= n_step <= 4:
+ logging.warning(f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps.")
+ else:
+ raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
+
+ def _encode_empty_text(self):
+ """
+ Encode text embedding for empty prompt.
+ """
+ prompt = ""
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="do_not_pad",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
+
+ @torch.no_grad()
+ def single_infer(
+ self,
+ rgb_in: torch.Tensor,
+ num_inference_steps: int,
+ seed: Union[int, None],
+ show_pbar: bool,
+ ) -> torch.Tensor:
+ """
+ Perform an individual depth prediction without ensembling.
+
+ Args:
+ rgb_in (`torch.Tensor`):
+ Input RGB image.
+ num_inference_steps (`int`):
+ Number of diffusion denoising steps (DDIM) during inference.
+ show_pbar (`bool`):
+ Display a progress bar of diffusion denoising.
+ Returns:
+ `torch.Tensor`: Predicted depth map.
+ """
+ device = rgb_in.device
+
+ # Set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps # [T]
+
+ # Encode image
+ rgb_latent = self.encode_rgb(rgb_in)
+
+ # Initial depth map (noise)
+ if seed is None:
+ rand_num_generator = None
+ else:
+ rand_num_generator = torch.Generator(device=device)
+ rand_num_generator.manual_seed(seed)
+ depth_latent = torch.randn(
+ rgb_latent.shape,
+ device=device,
+ dtype=self.dtype,
+ generator=rand_num_generator,
+ ) # [B, 4, h, w]
+
+ # Batched empty text embedding
+ if self.empty_text_embed is None:
+ self._encode_empty_text()
+ batch_empty_text_embed = self.empty_text_embed.repeat((rgb_latent.shape[0], 1, 1)) # [B, 2, 1024]
+
+ # Denoising loop
+ if show_pbar:
+ iterable = tqdm(
+ enumerate(timesteps),
+ total=len(timesteps),
+ leave=False,
+ desc=" " * 4 + "Diffusion denoising",
+ )
+ else:
+ iterable = enumerate(timesteps)
+
+ for i, t in iterable:
+ unet_input = torch.cat([rgb_latent, depth_latent], dim=1) # this order is important
+
+ # predict the noise residual
+ noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=rand_num_generator).prev_sample
+
+ depth = self.decode_depth(depth_latent)
+
+ # clip prediction
+ depth = torch.clip(depth, -1.0, 1.0)
+ # shift to [0, 1]
+ depth = (depth + 1.0) / 2.0
+
+ return depth
+
+ def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
+ """
+ Encode RGB image into latent.
+
+ Args:
+ rgb_in (`torch.Tensor`):
+ Input RGB image to be encoded.
+
+ Returns:
+ `torch.Tensor`: Image latent.
+ """
+ # encode
+ h = self.vae.encoder(rgb_in)
+ moments = self.vae.quant_conv(h)
+ mean, logvar = torch.chunk(moments, 2, dim=1)
+ # scale latent
+ rgb_latent = mean * self.rgb_latent_scale_factor
+ return rgb_latent
+
+ def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
+ """
+ Decode depth latent into depth map.
+
+ Args:
+ depth_latent (`torch.Tensor`):
+ Depth latent to be decoded.
+
+ Returns:
+ `torch.Tensor`: Decoded depth map.
+ """
+ # scale latent
+ depth_latent = depth_latent / self.depth_latent_scale_factor
+ # decode
+ z = self.vae.post_quant_conv(depth_latent)
+ stacked = self.vae.decoder(z)
+ # mean of output channels
+ depth_mean = stacked.mean(dim=1, keepdim=True)
+ return depth_mean
+
+ @staticmethod
+ def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
+ """
+ Resize image to limit maximum edge length while keeping aspect ratio.
+
+ Args:
+ img (`Image.Image`):
+ Image to be resized.
+ max_edge_resolution (`int`):
+ Maximum edge length (pixel).
+ resample_method (`PIL.Image.Resampling`):
+ Resampling method used to resize images.
+
+ Returns:
+ `Image.Image`: Resized image.
+ """
+ original_width, original_height = img.size
+ downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
+
+ new_width = int(original_width * downscale_factor)
+ new_height = int(original_height * downscale_factor)
+
+ resized_img = img.resize((new_width, new_height), resample=resample_method)
+ return resized_img
+
+ @staticmethod
+ def colorize_depth_maps(depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None):
+ """
+ Colorize depth maps.
+ """
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
+
+ if isinstance(depth_map, torch.Tensor):
+ depth = depth_map.detach().clone().squeeze().numpy()
+ elif isinstance(depth_map, np.ndarray):
+ depth = depth_map.copy().squeeze()
+ # reshape to [ (B,) H, W ]
+ if depth.ndim < 3:
+ depth = depth[np.newaxis, :, :]
+
+ # colorize
+ cm = matplotlib.colormaps[cmap]
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
+
+ if valid_mask is not None:
+ if isinstance(depth_map, torch.Tensor):
+ valid_mask = valid_mask.detach().numpy()
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
+ if valid_mask.ndim < 3:
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
+ else:
+ valid_mask = valid_mask[:, np.newaxis, :, :]
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
+ img_colored_np[~valid_mask] = 0
+
+ if isinstance(depth_map, torch.Tensor):
+ img_colored = torch.from_numpy(img_colored_np).float()
+ elif isinstance(depth_map, np.ndarray):
+ img_colored = img_colored_np
+
+ return img_colored
+
+ @staticmethod
+ def chw2hwc(chw):
+ assert 3 == len(chw.shape)
+ if isinstance(chw, torch.Tensor):
+ hwc = torch.permute(chw, (1, 2, 0))
+ elif isinstance(chw, np.ndarray):
+ hwc = np.moveaxis(chw, 0, -1)
+ return hwc
+
+ @staticmethod
+ def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
+ """
+ Automatically search for suitable operating batch size.
+
+ Args:
+ ensemble_size (`int`):
+ Number of predictions to be ensembled.
+ input_res (`int`):
+ Operating resolution of the input image.
+
+ Returns:
+ `int`: Operating batch size.
+ """
+ # Search table for suggested max. inference batch size
+ bs_search_table = [
+ # tested on A100-PCIE-80GB
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
+ # tested on A100-PCIE-40GB
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
+ # tested on RTX3090, RTX4090
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
+ # tested on GTX1080Ti
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
+ ]
+
+ if not torch.cuda.is_available():
+ return 1
+
+ total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
+ filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
+ for settings in sorted(
+ filtered_bs_search_table,
+ key=lambda k: (k["res"], -k["total_vram"]),
+ ):
+ if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
+ bs = settings["bs"]
+ if bs > ensemble_size:
+ bs = ensemble_size
+ elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
+ bs = math.ceil(ensemble_size / 2)
+ return bs
+
+ return 1
+
+ @staticmethod
+ def ensemble_depths(
+ input_images: torch.Tensor,
+ regularizer_strength: float = 0.02,
+ max_iter: int = 2,
+ tol: float = 1e-3,
+ reduction: str = "median",
+ max_res: int = None,
+ ):
+ """
+ Ensemble multiple affine-invariant depth images (each defined only up to scale and shift)
+ by jointly estimating a per-image scale and shift that aligns them.
+ """
+
+ def inter_distances(tensors: torch.Tensor):
+ """
+ Calculate the pairwise differences between the depth maps.
+ """
+ distances = []
+ for i, j in torch.combinations(torch.arange(tensors.shape[0])):
+ arr1 = tensors[i : i + 1]
+ arr2 = tensors[j : j + 1]
+ distances.append(arr1 - arr2)
+ dist = torch.concatenate(distances, dim=0)
+ return dist
+
+ device = input_images.device
+ dtype = input_images.dtype
+ np_dtype = np.float32
+
+ original_input = input_images.clone()
+ n_img = input_images.shape[0]
+ ori_shape = input_images.shape
+
+ if max_res is not None:
+ scale_factor = float(torch.min(max_res / torch.tensor(ori_shape[-2:])))
+ if scale_factor < 1:
+ downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
+ input_images = downscaler(input_images.unsqueeze(1)).squeeze(1)
+
+ # init guess
+ _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
+ _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
+ s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
+ t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
+ x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
+
+ input_images = input_images.to(device)
+
+ # objective function
+ def closure(x):
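+ # Jointly minimizes the pairwise differences between the aligned predictions plus a
+ # regularizer that pulls the ensemble's overall range towards [0, 1].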
+ l = len(x)
+ s = x[: int(l / 2)]
+ t = x[int(l / 2) :]
+ s = torch.from_numpy(s).to(dtype=dtype).to(device)
+ t = torch.from_numpy(t).to(dtype=dtype).to(device)
+
+ transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
+ dists = inter_distances(transformed_arrays)
+ sqrt_dist = torch.sqrt(torch.mean(dists**2))
+
+ if "mean" == reduction:
+ pred = torch.mean(transformed_arrays, dim=0)
+ elif "median" == reduction:
+ pred = torch.median(transformed_arrays, dim=0).values
+ else:
+ raise ValueError
+
+ near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
+ far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
+
+ err = sqrt_dist + (near_err + far_err) * regularizer_strength
+ err = err.detach().cpu().numpy().astype(np_dtype)
+ return err
+
+ res = minimize(
+ closure,
+ x,
+ method="BFGS",
+ tol=tol,
+ options={"maxiter": max_iter, "disp": False},
+ )
+ x = res.x
+ l = len(x)
+ s = x[: int(l / 2)]
+ t = x[int(l / 2) :]
+
+ # Prediction
+ s = torch.from_numpy(s).to(dtype=dtype).to(device)
+ t = torch.from_numpy(t).to(dtype=dtype).to(device)
+ transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
+ if "mean" == reduction:
+ aligned_images = torch.mean(transformed_arrays, dim=0)
+ std = torch.std(transformed_arrays, dim=0)
+ uncertainty = std
+ elif "median" == reduction:
+ aligned_images = torch.median(transformed_arrays, dim=0).values
+ # MAD (median absolute deviation) as uncertainty indicator
+ abs_dev = torch.abs(transformed_arrays - aligned_images)
+ mad = torch.median(abs_dev, dim=0).values
+ uncertainty = mad
+ else:
+ raise ValueError(f"Unknown reduction method: {reduction}")
+
+ # Scale and shift to [0, 1]
+ _min = torch.min(aligned_images)
+ _max = torch.max(aligned_images)
+ aligned_images = (aligned_images - _min) / (_max - _min)
+ uncertainty /= _max - _min
+
+ return aligned_images, uncertainty
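+
+# Example usage (a minimal sketch; the checkpoint name and image path are placeholders):
+#
+#   from diffusers import DiffusionPipeline
+#   from PIL import Image
+#
+#   pipe = DiffusionPipeline.from_pretrained(
+#       "prs-eth/marigold-v1-0",
+#       custom_pipeline="marigold_depth_estimation",
+#   ).to("cuda")
+#   out = pipe(Image.open("input.jpg"), denoising_steps=10, ensemble_size=10)
+#   out.depth_colored.save("depth_colored.png")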
diff --git a/diffusers/examples/community/masked_stable_diffusion_img2img.py b/diffusers/examples/community/masked_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..a210c167a2f3574ad0f730441377bc17167d4281
--- /dev/null
+++ b/diffusers/examples/community/masked_stable_diffusion_img2img.py
@@ -0,0 +1,262 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers import StableDiffusionImg2ImgPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+
+
+class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
+ debug_save = False
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.Tensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.Tensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ mask: Union[
+ torch.Tensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.Tensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image` or tensor representing an image batch to be used as the starting point. Can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ strength (`float`, *optional*, defaults to 0.8):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ mask (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):
+ A mask with non-zero elements for the area to be inpainted. If not specified, no mask is applied.
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # code adapted from parent class StableDiffusionImg2ImgPipeline
+
+ # 0. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 2. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 3. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ # it is sampled from the latent distribution of the VAE
+ latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+ )
+
+ # mean of the latent distribution
+ init_latents = [
+ self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+
+ # 6. create latent mask
+ latent_mask = self._make_latent_mask(latents, mask)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if latent_mask is not None:
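+ # outside the mask (latent_mask == 0) snap back to the encoded original image,
+ # inside the mask (latent_mask == 1) keep the denoised latents untouched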
+ latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask)
+ noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ scaled = latents / self.vae.config.scaling_factor
+ if latent_mask is not None:
+ # scaled = latents / self.vae.config.scaling_factor * latent_mask + init_latents * (1 - latent_mask)
+ scaled = torch.lerp(init_latents, scaled, latent_mask)
+ image = self.vae.decode(scaled, return_dict=False)[0]
+ if self.debug_save:
+ image_gen = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image_gen = self.image_processor.postprocess(image_gen, output_type=output_type, do_denormalize=[True])
+ image_gen[0].save("from_latent.png")
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def _make_latent_mask(self, latents, mask):
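+ # Convert each mask to a single-channel PIL image, resize it to the latent
+ # resolution, repeat it across the latent channels, and normalize it to [0, 1].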
+ if mask is not None:
+ latent_mask = []
+ if not isinstance(mask, list):
+ tmp_mask = [mask]
+ else:
+ tmp_mask = mask
+ _, l_channels, l_height, l_width = latents.shape
+ for m in tmp_mask:
+ if not isinstance(m, PIL.Image.Image):
+ if len(m.shape) == 2:
+ m = m[..., np.newaxis]
+ if m.max() > 1:
+ m = m / 255.0
+ m = self.image_processor.numpy_to_pil(m)[0]
+ if m.mode != "L":
+ m = m.convert("L")
+ resized = self.image_processor.resize(m, l_height, l_width)
+ if self.debug_save:
+ resized.save("latent_mask.png")
+ latent_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))
+ latent_mask = torch.as_tensor(np.stack(latent_mask)).to(latents)
+ latent_mask = latent_mask / latent_mask.max()
+ return latent_mask
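+
+# Example usage (a minimal sketch; the checkpoint name, image and mask are placeholders):
+#
+#   from diffusers import DiffusionPipeline
+#   from PIL import Image
+#   import numpy as np
+#
+#   pipe = DiffusionPipeline.from_pretrained(
+#       "runwayml/stable-diffusion-v1-5",
+#       custom_pipeline="masked_stable_diffusion_img2img",
+#   ).to("cuda")
+#   init_image = Image.open("input.png").convert("RGB")
+#   mask = np.array(Image.open("mask.png").convert("L"))  # non-zero pixels are repainted
+#   result = pipe(prompt="a red sofa", image=init_image, mask=mask, strength=0.75)
+#   result.images[0].save("masked_img2img.png")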
diff --git a/diffusers/examples/community/masked_stable_diffusion_xl_img2img.py b/diffusers/examples/community/masked_stable_diffusion_xl_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6b0ced527b5ebdfadc3aa6de648a47f9a2ed11f
--- /dev/null
+++ b/diffusers/examples/community/masked_stable_diffusion_xl_img2img.py
@@ -0,0 +1,682 @@
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image, ImageFilter
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
+ StableDiffusionXLImg2ImgPipeline,
+ rescale_noise_cfg,
+ retrieve_latents,
+ retrieve_timesteps,
+)
+from diffusers.utils import (
+ deprecate,
+ is_torch_xla_available,
+ logging,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class MaskedStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
+ debug_save = 0
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ original_image: PipelineImageInput = None,
+ strength: float = 0.3,
+ num_inference_steps: Optional[int] = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: Optional[float] = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ mask: Union[
+ torch.FloatTensor,
+ Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[Image.Image],
+ List[np.ndarray],
+ ] = None,
+ blur=24,
+ blur_compose=4,
+ sample_mode="sample",
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`PipelineImageInput`):
+ `Image` or tensor representing an image batch to be used as the starting point. The mask may be painted directly onto this image; it is then derived from the difference to `original_image`.
+ original_image (`PipelineImageInput`, *optional*):
+ `Image` or tensor representing an image batch to be used for blending with the result.
+ strength (`float`, *optional*, defaults to 0.3):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ blur (`int`, *optional*, defaults to 24):
+ Blur to apply to the mask.
+ blur_compose (`int`, *optional*, defaults to 4):
+ Blur to apply to the mask when compositing the original image with the generated result.
+ mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):
+ A mask with non-zero elements for the area to be inpainted. If not specified, no mask is applied.
+ sample_mode (`str`, *optional*, defaults to `"sample"`):
+ Controls how the latents for the inpaint area are initialised; one of `"sample"`, `"argmax"`, or `"random"`.
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # code adapted from parent class StableDiffusionXLImg2ImgPipeline
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+
+ # 1. Define call parameters
+ # mask is computed from difference between image and original_image
+ if image is not None:
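+ # any pixel that differs from `original_image` in any channel becomes part of the repaint mask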
+ neq = np.any(np.array(original_image) != np.array(image), axis=-1)
+ mask = neq.astype(np.uint8) * 255
+ else:
+ assert mask is not None
+
+ if not isinstance(mask, Image.Image):
+ pil_mask = Image.fromarray(mask)
+ if pil_mask.mode != "L":
+ pil_mask = pil_mask.convert("L")
+ mask_blur = self.blur_mask(pil_mask, blur)
+ mask_compose = self.blur_mask(pil_mask, blur_compose)
+ if original_image is None:
+ original_image = image
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 2. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 3. Preprocess image
+ input_image = image if image is not None else original_image
+ image = self.image_processor.preprocess(input_image)
+ original_image = self.image_processor.preprocess(original_image)
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ add_noise = True if self.denoising_start is None else False
+
+ # 5. Prepare latent variables
+ # It is sampled from the latent distribution of the VAE
+ # that's what we repaint
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ add_noise,
+ sample_mode=sample_mode,
+ )
+
+ # mean of the latent distribution
+ # it is multiplied by self.vae.config.scaling_factor
+ non_paint_latents = self.prepare_latents(
+ original_image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ add_noise=False,
+ sample_mode="argmax",
+ )
+
+ if self.debug_save:
+ init_img_from_latents = self.latents_to_img(non_paint_latents)
+ init_img_from_latents[0].save("non_paint_latents.png")
+ # 6. create latent mask
+ latent_mask = self._make_latent_mask(latents, mask)
+
+ # 7. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 8. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 10. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 10.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and self.denoising_start is not None
+ and denoising_value_valid(self.denoising_end)
+ and denoising_value_valid(self.denoising_start)
+ and self.denoising_start >= self.denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {self.denoising_end} when using type float."
+ )
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 10.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ shape = non_paint_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
+ # noisy latent code of input image at current step
+ orig_latents_t = self.scheduler.add_noise(non_paint_latents, noise, t.unsqueeze(0))
+
+ # orig_latents_t * (1 - latent_mask) + latents * latent_mask
+ latents = torch.lerp(orig_latents_t, latents, latent_mask)
+
+ if self.debug_save:
+ img1 = self.latents_to_img(latents)
+ # zero-pad the timestep so the debug filenames sort correctly
+ t_str = f"{t.int().item():03d}"
+ img1[0].save(f"step{t_str}.png")
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ elif latents.dtype != self.vae.dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ self.vae = self.vae.to(latents.dtype)
+
+ if self.debug_save:
+ image_gen = self.latents_to_img(latents)
+ image_gen[0].save("from_latent.png")
+
+ if latent_mask is not None:
+ # interpolate with latent mask
+ latents = torch.lerp(non_paint_latents, latents, latent_mask)
+
+ latents = self.denormalize(latents)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ m = mask_compose.permute(2, 0, 1).unsqueeze(0).to(image)
+ img_compose = m * image + (1 - m) * original_image.to(image)
+ image = img_compose
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ def _make_latent_mask(self, latents, mask):
+ if mask is not None:
+ latent_mask = []
+ if not isinstance(mask, list):
+ tmp_mask = [mask]
+ else:
+ tmp_mask = mask
+ _, l_channels, l_height, l_width = latents.shape
+ for m in tmp_mask:
+ if not isinstance(m, Image.Image):
+ if len(m.shape) == 2:
+ m = m[..., np.newaxis]
+ if m.max() > 1:
+ m = m / 255.0
+ m = self.image_processor.numpy_to_pil(m)[0]
+ if m.mode != "L":
+ m = m.convert("L")
+ resized = self.image_processor.resize(m, l_height, l_width)
+ if self.debug_save:
+ resized.save("latent_mask.png")
+ latent_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))
+ latent_mask = torch.as_tensor(np.stack(latent_mask)).to(latents)
+ latent_mask = latent_mask / max(latent_mask.max(), 1)
+ return latent_mask
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_images_per_prompt,
+ dtype,
+ device,
+ generator=None,
+ add_noise=True,
+ sample_mode: str = "sample",
+ ):
+ if not isinstance(image, (torch.Tensor, Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.text_encoder_2.to("cpu")
+ torch.cuda.empty_cache()
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+ elif sample_mode == "random":
+ height, width = image.shape[-2:]
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.random_latents(
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ )
+ return self.vae.config.scaling_factor * latents
+ else:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(
+ self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode
+ )
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode=sample_mode)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ if add_noise:
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def random_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def denormalize(self, latents):
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ return latents
+
+ def latents_to_img(self, latents):
+ l1 = self.denormalize(latents)
+ img1 = self.vae.decode(l1, return_dict=False)[0]
+ img1 = self.image_processor.postprocess(img1, output_type="pil", do_denormalize=[True])
+ return img1
+
+ def blur_mask(self, pil_mask, blur):
+ mask_blur = pil_mask.filter(ImageFilter.GaussianBlur(radius=blur))
+ mask_blur = np.array(mask_blur)
+ return torch.from_numpy(np.tile(mask_blur / mask_blur.max(), (3, 1, 1)).transpose(1, 2, 0))
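+
+
+if __name__ == "__main__":
+ # Minimal, self-contained sketch (illustration only, not part of the pipeline API)
+ # of the latent-mask blending used in the denoising loop above: the masked region
+ # keeps the freshly denoised latents, while the rest is taken from the re-noised
+ # latents of the original image, i.e. orig * (1 - mask) + new * mask.
+ import torch
+
+ orig = torch.zeros(1, 4, 8, 8)
+ repainted = torch.ones(1, 4, 8, 8)
+ mask = torch.zeros(1, 4, 8, 8)
+ mask[..., 2:6, 2:6] = 1.0  # repaint only the central 4x4 latent patch
+ blended = torch.lerp(orig, repainted, mask)
+ assert blended[0, 0, 0, 0].item() == 0.0 and blended[0, 0, 3, 3].item() == 1.0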
diff --git a/diffusers/examples/community/mixture_canvas.py b/diffusers/examples/community/mixture_canvas.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bb054a123d055f3b00ed72e58507f82e687f407
--- /dev/null
+++ b/diffusers/examples/community/mixture_canvas.py
@@ -0,0 +1,501 @@
+import re
+from copy import deepcopy
+from dataclasses import asdict, dataclass
+from enum import Enum
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from numpy import exp, pi, sqrt
+from torchvision.transforms.functional import resize
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+
+def preprocess_image(image):
+ """Preprocess an input image
+
+ Same as
+ https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44
+ """
+ from PIL import Image
+
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+@dataclass
+class CanvasRegion:
+ """Class defining a rectangular region in the canvas"""
+
+ row_init: int # Region starting row in pixel space (included)
+ row_end: int # Region end row in pixel space (not included)
+ col_init: int # Region starting column in pixel space (included)
+ col_end: int # Region end column in pixel space (not included)
+ region_seed: int = None # Seed for random operations in this region
+ noise_eps: float = 0.0 # Deviation of a zero-mean gaussian noise to be applied over the latents in this region. Useful for slightly "rerolling" latents
+
+ def __post_init__(self):
+ # Initialize arguments if not specified
+ if self.region_seed is None:
+ self.region_seed = np.random.randint(9999999999)
+ # Check coordinates are non-negative
+ for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:
+ if coord < 0:
+ raise ValueError(
+ f"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})"
+ )
+ # Check coordinates are divisible by 8, else we end up with nasty rounding error when mapping to latent space
+ for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:
+ if coord // 8 != coord / 8:
+ raise ValueError(
+ f"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})"
+ )
+ # Check noise eps is non-negative
+ if self.noise_eps < 0:
+ raise ValueError(f"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}")
+ # Compute coordinates for this region in latent space
+ self.latent_row_init = self.row_init // 8
+ self.latent_row_end = self.row_end // 8
+ self.latent_col_init = self.col_init // 8
+ self.latent_col_end = self.col_end // 8
+
+ @property
+ def width(self):
+ return self.col_end - self.col_init
+
+ @property
+ def height(self):
+ return self.row_end - self.row_init
+
+ def get_region_generator(self, device="cpu"):
+ """Creates a torch.Generator based on the random seed of this region"""
+ # Initialize region generator
+ return torch.Generator(device).manual_seed(self.region_seed)
+
+ @property
+ def __dict__(self):
+ return asdict(self)
+
+
+class MaskModes(Enum):
+ """Modes in which the influence of diffuser is masked"""
+
+ CONSTANT = "constant"
+ GAUSSIAN = "gaussian"
+ QUARTIC = "quartic" # See https://en.wikipedia.org/wiki/Kernel_(statistics)
+
+
+@dataclass
+class DiffusionRegion(CanvasRegion):
+ """Abstract class defining a region where some class of diffusion process is acting"""
+
+ pass
+
+
+@dataclass
+class Text2ImageRegion(DiffusionRegion):
+ """Class defining a region where a text guided diffusion process is acting"""
+
+ prompt: str = "" # Text prompt guiding the diffuser in this region
+ guidance_scale: float = 7.5 # Guidance scale of the diffuser in this region. If None, randomize
+ mask_type: MaskModes = MaskModes.GAUSSIAN.value # Kind of weight mask applied to this region
+ mask_weight: float = 1.0 # Global weights multiplier of the mask
+ tokenized_prompt = None # Tokenized prompt
+ encoded_prompt = None # Encoded prompt
+
+ def __post_init__(self):
+ super().__post_init__()
+ # Mask weight cannot be negative
+ if self.mask_weight < 0:
+ raise ValueError(
+ f"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}"
+ )
+ # Mask type must be an actual known mask
+ if self.mask_type not in [e.value for e in MaskModes]:
+ raise ValueError(
+ f"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({[e.value for e in MaskModes]})"
+ )
+ # Randomize arguments if given as None
+ if self.guidance_scale is None:
+ self.guidance_scale = np.random.randint(5, 30)
+ # Clean prompt
+ self.prompt = re.sub(" +", " ", self.prompt).replace("\n", " ")
+
+ def tokenize_prompt(self, tokenizer):
+ """Tokenizes the prompt for this diffusion region using a given tokenizer"""
+ self.tokenized_prompt = tokenizer(
+ self.prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ def encode_prompt(self, text_encoder, device):
+ """Encodes the previously tokenized prompt for this diffusion region using a given encoder"""
+ assert self.tokenized_prompt is not None, ValueError(
+ "Prompt in diffusion region must be tokenized before encoding"
+ )
+ self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0]
+
+
+@dataclass
+class Image2ImageRegion(DiffusionRegion):
+ """Class defining a region where an image guided diffusion process is acting"""
+
+ reference_image: torch.Tensor = None
+ strength: float = 0.8 # Strength of the image
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self.reference_image is None:
+ raise ValueError("Must provide a reference image when creating an Image2ImageRegion")
+ if self.strength < 0 or self.strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {self.strength}")
+ # Rescale image to region shape
+ self.reference_image = resize(self.reference_image, size=[self.height, self.width])
+
+ def encode_reference_image(self, encoder, device, generator, cpu_vae=False):
+ """Encodes the reference image for this Image2Image region into the latent space"""
+ # Place encoder in CPU or not following the parameter cpu_vae
+ if cpu_vae:
+ # Note: we use the mean here instead of sampling, to avoid also having to move the generator to the CPU, which is troublesome
+ self.reference_latents = encoder.cpu().encode(self.reference_image).latent_dist.mean.to(device)
+ else:
+ self.reference_latents = encoder.encode(self.reference_image.to(device)).latent_dist.sample(
+ generator=generator
+ )
+ self.reference_latents = 0.18215 * self.reference_latents
+
+ @property
+ def __dict__(self):
+ # This class requires special conversion to dict because of the reference_image tensor; otherwise it cannot be serialized to JSON
+
+ # Get all basic fields from parent class
+ super_fields = {key: getattr(self, key) for key in DiffusionRegion.__dataclass_fields__.keys()}
+ # Pack other fields
+ return {**super_fields, "reference_image": self.reference_image.cpu().tolist(), "strength": self.strength}
+
+
+class RerollModes(Enum):
+ """Modes in which the reroll regions operate"""
+
+ RESET = "reset" # Completely reset the random noise in the region
+ EPSILON = "epsilon" # Alter slightly the latents in the region
+
+
+@dataclass
+class RerollRegion(CanvasRegion):
+ """Class defining a rectangular canvas region in which initial latent noise will be rerolled"""
+
+ reroll_mode: RerollModes = RerollModes.RESET.value
+
+
+@dataclass
+class MaskWeightsBuilder:
+ """Auxiliary class to compute a tensor of weights for a given diffusion region"""
+
+ latent_space_dim: int # Size of the U-net latent space
+ nbatch: int = 1 # Batch size in the U-net
+
+ def compute_mask_weights(self, region: DiffusionRegion) -> torch.Tensor:
+ """Computes a tensor of weights for a given diffusion region"""
+ MASK_BUILDERS = {
+ MaskModes.CONSTANT.value: self._constant_weights,
+ MaskModes.GAUSSIAN.value: self._gaussian_weights,
+ MaskModes.QUARTIC.value: self._quartic_weights,
+ }
+ return MASK_BUILDERS[region.mask_type](region)
+
+ def _constant_weights(self, region: DiffusionRegion) -> torch.Tensor:
+ """Computes a tensor of constant weights for a given diffusion region"""
+ latent_width = region.latent_col_end - region.latent_col_init
+ latent_height = region.latent_row_end - region.latent_row_init
+ return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight
+
+ def _gaussian_weights(self, region: DiffusionRegion) -> torch.Tensor:
+ """Generates a gaussian mask of weights for tile contributions"""
+ latent_width = region.latent_col_end - region.latent_col_init
+ latent_height = region.latent_row_end - region.latent_row_init
+
+ var = 0.01
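+ # The variance is expressed in region-size-normalized coordinates (std = 0.1 of the
+ # region width/height), so weights peak at the region centre and decay smoothly
+ # toward its borders; overlapping regions are later averaged using these weights
+ # as their relative contribution.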
+ midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
+ x_probs = [
+ exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
+ for x in range(latent_width)
+ ]
+ midpoint = (latent_height - 1) / 2
+ y_probs = [
+ exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
+ for y in range(latent_height)
+ ]
+
+ weights = np.outer(y_probs, x_probs) * region.mask_weight
+ return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))
+
+ def _quartic_weights(self, region: DiffusionRegion) -> torch.Tensor:
+ """Generates a quartic mask of weights for tile contributions
+
+ The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits.
+ """
+ quartic_constant = 15.0 / 16.0
+
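+ # Rescale latent indices to the open interval (-0.995, 0.995) so that the
+ # quartic (biweight) kernel 15/16 * (1 - u**2)**2 is positive everywhere but
+ # nearly zero at the region borders.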
+ support = (np.array(range(region.latent_col_init, region.latent_col_end)) - region.latent_col_init) / (
+ region.latent_col_end - region.latent_col_init - 1
+ ) * 1.99 - (1.99 / 2.0)
+ x_probs = quartic_constant * np.square(1 - np.square(support))
+ support = (np.array(range(region.latent_row_init, region.latent_row_end)) - region.latent_row_init) / (
+ region.latent_row_end - region.latent_row_init - 1
+ ) * 1.99 - (1.99 / 2.0)
+ y_probs = quartic_constant * np.square(1 - np.square(support))
+
+ weights = np.outer(y_probs, x_probs) * region.mask_weight
+ return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))
+
+
+class StableDiffusionCanvasPipeline(DiffusionPipeline, StableDiffusionMixin):
+ """Stable Diffusion pipeline that mixes several diffusers in the same canvas"""
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def decode_latents(self, latents, cpu_vae=False):
+ """Decodes a given array of latents into pixel space"""
+ # scale and decode the image latents with vae
+ if cpu_vae:
+ lat = deepcopy(latents).cpu()
+ vae = deepcopy(self.vae).cpu()
+ else:
+ lat = latents
+ vae = self.vae
+
+ lat = 1 / 0.18215 * lat
+ image = vae.decode(lat).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ return self.numpy_to_pil(image)
+
+ def get_latest_timestep_img2img(self, num_inference_steps, strength):
+ """Finds the latest timesteps where an img2img strength does not impose latents anymore"""
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * (1 - strength)) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps - 1)
+ latest_timestep = self.scheduler.timesteps[t_start]
+
+ return latest_timestep
+
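+ # Worked example for get_latest_timestep_img2img (illustrative): with
+ # num_inference_steps=50, strength=0.8 and steps_offset=0, init_timestep = 10 and
+ # t_start = 40, so the reference latents are only imposed for timesteps greater
+ # than self.scheduler.timesteps[40].
+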
+ @torch.no_grad()
+ def __call__(
+ self,
+ canvas_height: int,
+ canvas_width: int,
+ regions: List[DiffusionRegion],
+ num_inference_steps: Optional[int] = 50,
+ seed: Optional[int] = 12345,
+ reroll_regions: Optional[List[RerollRegion]] = None,
+ cpu_vae: Optional[bool] = False,
+ decode_steps: Optional[bool] = False,
+ ):
+ if reroll_regions is None:
+ reroll_regions = []
+ batch_size = 1
+
+ if decode_steps:
+ steps_images = []
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+
+ # Split diffusion regions by their kind
+ text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)]
+ image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)]
+
+ # Prepare text embeddings
+ for region in text2image_regions:
+ region.tokenize_prompt(self.tokenizer)
+ region.encode_prompt(self.text_encoder, self.device)
+
+ # Create original noisy latents using the timesteps
+ latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8)
+ generator = torch.Generator(self.device).manual_seed(seed)
+ init_noise = torch.randn(latents_shape, generator=generator, device=self.device)
+
+ # Reset latents in seed reroll regions, if requested
+ for region in reroll_regions:
+ if region.reroll_mode == RerollModes.RESET.value:
+ region_shape = (
+ latents_shape[0],
+ latents_shape[1],
+ region.latent_row_end - region.latent_row_init,
+ region.latent_col_end - region.latent_col_init,
+ )
+ init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] = torch.randn(region_shape, generator=region.get_region_generator(self.device), device=self.device)
+
+ # Apply epsilon noise to regions: first diffusion regions, then reroll regions
+ all_eps_rerolls = regions + [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value]
+ for region in all_eps_rerolls:
+ if region.noise_eps > 0:
+ region_noise = init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ]
+ eps_noise = (
+ torch.randn(
+ region_noise.shape, generator=region.get_region_generator(self.device), device=self.device
+ )
+ * region.noise_eps
+ )
+ init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] += eps_noise
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = init_noise * self.scheduler.init_noise_sigma
+
+ # Get unconditional embeddings for classifier free guidance in text2image regions
+ for region in text2image_regions:
+ max_length = region.tokenized_prompt.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ region.encoded_prompt = torch.cat([uncond_embeddings, region.encoded_prompt])
+
+ # Prepare image latents
+ for region in image2image_regions:
+ region.encode_reference_image(self.vae, device=self.device, generator=generator)
+
+ # Prepare mask of weights for each region
+ mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size)
+ mask_weights = [mask_builder.compute_mask_weights(region).to(self.device) for region in text2image_regions]
+
+ # Diffusion timesteps
+ for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+ # Diffuse each region
+ noise_preds_regions = []
+
+ # text2image regions
+ for region in text2image_regions:
+ region_latents = latents[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ]
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([region_latents] * 2)
+ # scale model input following scheduler rules
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)["sample"]
+ # perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_region = noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_preds_regions.append(noise_pred_region)
+
+ # Merge noise predictions for all tiles
+ noise_pred = torch.zeros(latents.shape, device=self.device)
+ contributors = torch.zeros(latents.shape, device=self.device)
+ # Add each tile contribution to overall latents
+ for region, noise_pred_region, mask_weights_region in zip(
+ text2image_regions, noise_preds_regions, mask_weights
+ ):
+ noise_pred[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] += noise_pred_region * mask_weights_region
+ contributors[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] += mask_weights_region
+ # Average overlapping areas with more than 1 contributor
+ noise_pred /= contributors
+ noise_pred = torch.nan_to_num(
+ noise_pred
+ ) # Replace NaNs by zeros: NaN can appear if a position is not covered by any DiffusionRegion
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ # Image2Image regions: override latents generated by the scheduler
+ for region in image2image_regions:
+ influence_step = self.get_latest_timestep_img2img(num_inference_steps, region.strength)
+ # Only override in the timesteps before the last influence step of the image (given by its strength)
+ if t > influence_step:
+ timestep = t.repeat(batch_size)
+ region_init_noise = init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ]
+ region_latents = self.scheduler.add_noise(region.reference_latents, region_init_noise, timestep)
+ latents[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] = region_latents
+
+ if decode_steps:
+ steps_images.append(self.decode_latents(latents, cpu_vae))
+
+ # scale and decode the image latents with vae
+ image = self.decode_latents(latents, cpu_vae)
+
+ output = {"images": image}
+ if decode_steps:
+ output = {**output, "steps_images": steps_images}
+ return output
diff --git a/diffusers/examples/community/mixture_tiling.py b/diffusers/examples/community/mixture_tiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..867bce0d9eb8c9bb7d191144267fcc9a09980dee
--- /dev/null
+++ b/diffusers/examples/community/mixture_tiling.py
@@ -0,0 +1,405 @@
+import inspect
+from copy import deepcopy
+from enum import Enum
+from typing import List, Optional, Tuple, Union
+
+import torch
+from tqdm.auto import tqdm
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import logging
+
+
+try:
+ from ligo.segments import segment
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+except ImportError:
+ raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import LMSDiscreteScheduler, DiffusionPipeline
+
+ >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, custom_pipeline="mixture_tiling")
+ >>> pipeline.to("cuda")
+
+ >>> image = pipeline(
+ >>> prompt=[[
+ >>> "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ >>> "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ >>> "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
+ >>> ]],
+ >>> tile_height=640,
+ >>> tile_width=640,
+ >>> tile_row_overlap=0,
+ >>> tile_col_overlap=256,
+ >>> guidance_scale=8,
+ >>> seed=7178915308,
+ >>> num_inference_steps=50,
+ >>> )["images"][0]
+ ```
+"""
+
+
+def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
+ """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in pixel space
+ - Ending coordinates of rows in pixel space
+ - Starting coordinates of columns in pixel space
+ - Ending coordinates of columns in pixel space
+ """
+ px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)
+ px_row_end = px_row_init + tile_height
+ px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)
+ px_col_end = px_col_init + tile_width
+ return px_row_init, px_row_end, px_col_init, px_col_end
+
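+# Worked example (illustrative): with tile_height = tile_width = 640 and
+# tile_row_overlap = tile_col_overlap = 256, tile (row=1, col=2) covers pixel rows
+# 384..1024 and columns 768..1408, since each new tile advances by
+# tile_size - overlap = 384 pixels in each dimension.
+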
+
+def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end):
+ """Translates coordinates in pixel space to coordinates in latent space"""
+ return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8
+
+
+def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
+ """Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in latent space
+ - Ending coordinates of rows in latent space
+ - Starting coordinates of columns in latent space
+ - Ending coordinates of columns in latent space
+ """
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end)
+
+
+def _tile2latent_exclusive_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns
+):
+ """Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in latent space
+ - Ending coordinates of rows in latent space
+ - Starting coordinates of columns in latent space
+ - Ending coordinates of columns in latent space
+ """
+ row_init, row_end, col_init, col_end = _tile2latent_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ row_segment = segment(row_init, row_end)
+ col_segment = segment(col_init, col_end)
+ # Iterate over the rest of tiles, clipping the region for the current tile
+ for row in range(rows):
+ for column in range(columns):
+ if row != tile_row and column != tile_col:
+ clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices(
+ row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ row_segment = row_segment - segment(clip_row_init, clip_row_end)
+ col_segment = col_segment - segment(clip_col_init, clip_col_end)
+ # return row_init, row_end, col_init, col_end
+ return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
+
+
+class StableDiffusionExtrasMixin:
+ """Mixin providing additional convenience method to Stable Diffusion pipelines"""
+
+ def decode_latents(self, latents, cpu_vae=False):
+ """Decodes a given array of latents into pixel space"""
+ # scale and decode the image latents with vae
+ if cpu_vae:
+ lat = deepcopy(latents).cpu()
+ vae = deepcopy(self.vae).cpu()
+ else:
+ lat = latents
+ vae = self.vae
+
+ lat = 1 / 0.18215 * lat
+ image = vae.decode(lat).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ return self.numpy_to_pil(image)
+
+
+class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixin):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ class SeedTilesMode(Enum):
+ """Modes in which the latents of a particular tile can be re-seeded"""
+
+ FULL = "full"
+ EXCLUSIVE = "exclusive"
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[List[str]]],
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ seed: Optional[int] = None,
+ tile_height: Optional[int] = 512,
+ tile_width: Optional[int] = 512,
+ tile_row_overlap: Optional[int] = 256,
+ tile_col_overlap: Optional[int] = 256,
+ guidance_scale_tiles: Optional[List[List[float]]] = None,
+ seed_tiles: Optional[List[List[int]]] = None,
+ seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
+ seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
+ cpu_vae: Optional[bool] = False,
+ ):
+ r"""
+ Function to run the diffusion pipeline with tiling support.
+
+ Args:
+ prompt: a list of lists with all the prompts to use, one inner list per row of tiles. This also defines the tiling structure.
+ num_inference_steps: number of diffusion steps.
+ guidance_scale: classifier-free guidance scale.
+ seed: general random seed to initialize latents.
+ tile_height: height in pixels of each grid tile.
+ tile_width: width in pixels of each grid tile.
+ tile_row_overlap: number of overlap pixels between tiles in consecutive rows.
+ tile_col_overlap: number of overlap pixels between tiles in consecutive columns.
+ guidance_scale_tiles: specific weights for classifier-free guidance in each tile. If None, the value provided in guidance_scale will be used.
+ seed_tiles: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard seed parameter.
+ seed_tiles_mode: either "full" or "exclusive". If "full", all the latents affected by the tile will be overridden. If "exclusive", only the latents that are affected exclusively by this tile (and no other tiles) will be overridden.
+ seed_reroll_regions: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over seed_tiles.
+ cpu_vae: the decoder from latent space to pixel space can require too much GPU RAM for large images. If you run into out-of-memory errors at the end of the generation process, try setting this parameter to True to run the decoder on the CPU. Slower, but should run without memory issues.
+
+ Examples:
+
+ Returns:
+ A dictionary with an "images" entry containing the generated PIL images.
+
+ """
+ if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
+ raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
+ grid_rows = len(prompt)
+ grid_cols = len(prompt[0])
+ if not all(len(row) == grid_cols for row in prompt):
+ raise ValueError("All prompt rows must have the same number of prompt columns")
+ if not isinstance(seed_tiles_mode, str) and (
+ not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)
+ ):
+ raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
+ if isinstance(seed_tiles_mode, str):
+ seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
+
+ modes = [mode.value for mode in self.SeedTilesMode]
+ if any(mode not in modes for row in seed_tiles_mode for mode in row):
+ raise ValueError(f"Seed tiles mode must be one of {modes}")
+ if seed_reroll_regions is None:
+ seed_reroll_regions = []
+ batch_size = 1
+
+ # create original noisy latents using the timesteps
+ height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
+ width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
+ latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
+ generator = torch.Generator("cuda").manual_seed(seed)
+ latents = torch.randn(latents_shape, generator=generator, device=self.device)
+
+ # overwrite latents for specific tiles if provided
+ if seed_tiles is not None:
+ for row in range(grid_rows):
+ for col in range(grid_cols):
+ if (seed_tile := seed_tiles[row][col]) is not None:
+ mode = seed_tiles_mode[row][col]
+ if mode == self.SeedTilesMode.FULL.value:
+ row_init, row_end, col_init, col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ else:
+ row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices(
+ row,
+ col,
+ tile_width,
+ tile_height,
+ tile_row_overlap,
+ tile_col_overlap,
+ grid_rows,
+ grid_cols,
+ )
+ tile_generator = torch.Generator("cuda").manual_seed(seed_tile)
+ tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
+ latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
+ tile_shape, generator=tile_generator, device=self.device
+ )
+
+ # overwrite again for seed reroll regions
+ for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions:
+ row_init, row_end, col_init, col_end = _pixel2latent_indices(
+ row_init, row_end, col_init, col_end
+ ) # to latent space coordinates
+ reroll_generator = torch.Generator("cuda").manual_seed(seed_reroll)
+ region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
+ latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
+ region_shape, generator=reroll_generator, device=self.device
+ )
+
+ # Prepare scheduler
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # get prompts text embeddings
+ text_input = [
+ [
+ self.tokenizer(
+ col,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ for col in row
+ ]
+ for row in prompt
+ ]
+ text_embeddings = [[self.text_encoder(col.input_ids.to(self.device))[0] for col in row] for row in text_input]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0 # TODO: also active if any tile has guidance scale
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ for i in range(grid_rows):
+ for j in range(grid_cols):
+ max_length = text_input[i][j].input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings[i][j] = torch.cat([uncond_embeddings, text_embeddings[i][j]])
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # Gaussian mask weighting each tile's contribution when stitching overlapping tiles
+ tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size)
+
+ # Diffusion timesteps
+ for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+ # Diffuse each tile
+ noise_preds = []
+ for row in range(grid_rows):
+ noise_preds_row = []
+ for col in range(grid_cols):
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[row][col])[
+ "sample"
+ ]
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ guidance = (
+ guidance_scale
+ if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None
+ else guidance_scale_tiles[row][col]
+ )
+ noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
+ noise_preds_row.append(noise_pred_tile)
+ noise_preds.append(noise_preds_row)
+ # Stitch noise predictions for all tiles
+ noise_pred = torch.zeros(latents.shape, device=self.device)
+ contributors = torch.zeros(latents.shape, device=self.device)
+ # Add each tile contribution to overall latents
+ for row in range(grid_rows):
+ for col in range(grid_cols):
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (
+ noise_preds[row][col] * tile_weights
+ )
+ contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights
+ # Average overlapping areas with more than 1 contributor
+ noise_pred /= contributors
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ # scale and decode the image latents with vae
+ image = self.decode_latents(latents, cpu_vae)
+
+ return {"images": image}
+
+ def _gaussian_weights(self, tile_width, tile_height, nbatches):
+ """Generates a gaussian mask of weights for tile contributions"""
+ import numpy as np
+ from numpy import exp, pi, sqrt
+
+ latent_width = tile_width // 8
+ latent_height = tile_height // 8
+
+ var = 0.01
+ midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
+ x_probs = [
+ exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
+ for x in range(latent_width)
+ ]
+ midpoint = (latent_height - 1) / 2  # -1 because index goes from 0 to latent_height - 1
+ y_probs = [
+ exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
+ for y in range(latent_height)
+ ]
+
+ weights = np.outer(y_probs, x_probs)
+ return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
diff --git a/diffusers/examples/community/multilingual_stable_diffusion.py b/diffusers/examples/community/multilingual_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc335e0b585e13158e1bfd9b513294a83193ec79
--- /dev/null
+++ b/diffusers/examples/community/multilingual_stable_diffusion.py
@@ -0,0 +1,410 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ MBart50TokenizerFast,
+ MBartForConditionalGeneration,
+ pipeline,
+)
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def detect_language(pipe, prompt, batch_size):
+ """helper function to detect language(s) of prompt"""
+
+ if batch_size == 1:
+ preds = pipe(prompt, top_k=1, truncation=True, max_length=128)
+ return preds[0]["label"]
+ else:
+ detected_languages = []
+ for p in prompt:
+ preds = pipe(p, top_k=1, truncation=True, max_length=128)
+ detected_languages.append(preds[0]["label"])
+
+ return detected_languages
+
+
+def translate_prompt(prompt, translation_tokenizer, translation_model, device):
+ """helper function to translate prompt to English"""
+
+ encoded_prompt = translation_tokenizer(prompt, return_tensors="pt").to(device)
+ generated_tokens = translation_model.generate(**encoded_prompt, max_new_tokens=1000)
+ en_trans = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+
+ return en_trans[0]
+
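+# Illustrative usage of the two helpers above (the prompt string and device are
+# placeholders, not values mandated by this file):
+#   lang = detect_language(detection_pipeline, "Une maison à la campagne", batch_size=1)
+#   prompt_en = translate_prompt("Une maison à la campagne", translation_tokenizer, translation_model, "cuda")
+# so that a non-English prompt can be translated to English before being passed to the CLIP text encoder.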
+
+class MultilingualStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion in different languages.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ detection_pipeline ([`pipeline`]):
+ Transformers pipeline to detect prompt's language.
+ translation_model ([`MBartForConditionalGeneration`]):
+ Model to translate prompt to English, if necessary. Please refer to the
+ [model card](https://huggingface.co/docs/transformers/model_doc/mbart) for details.
+ translation_tokenizer ([`MBart50TokenizerFast`]):
+ Tokenizer of the translation model.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ detection_pipeline: pipeline,
+ translation_model: MBartForConditionalGeneration,
+ translation_tokenizer: MBart50TokenizerFast,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ detection_pipeline=detection_pipeline,
+ translation_model=translation_model,
+ translation_tokenizer=translation_tokenizer,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation. Can be in different languages.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`]; it is ignored for other schedulers.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # detect language and translate if necessary
+ prompt_language = detect_language(self.detection_pipeline, prompt, batch_size)
+ if batch_size == 1 and prompt_language != "en":
+ prompt = translate_prompt(prompt, self.translation_tokenizer, self.translation_model, self.device)
+
+ if isinstance(prompt, list):
+ for index in range(batch_size):
+ if prompt_language[index] != "en":
+ p = translate_prompt(
+ prompt[index], self.translation_tokenizer, self.translation_model, self.device
+ )
+ prompt[index] = p
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ # detect language and translate it if necessary
+ negative_prompt_language = detect_language(self.detection_pipeline, negative_prompt, batch_size)
+ if negative_prompt_language != "en":
+ negative_prompt = translate_prompt(
+ negative_prompt, self.translation_tokenizer, self.translation_model, self.device
+ )
+ if isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ # detect language and translate it if necessary
+ if isinstance(negative_prompt, list):
+ negative_prompt_languages = detect_language(self.detection_pipeline, negative_prompt, batch_size)
+ for index in range(batch_size):
+ if negative_prompt_languages[index] != "en":
+ p = translate_prompt(
+ negative_prompt[index], self.translation_tokenizer, self.translation_model, self.device
+ )
+ negative_prompt[index] = p
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
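+ # 0.18215 is the latent scaling factor of the Stable Diffusion v1 VAE; dividing by it
+ # rescales the latents back to the VAE's native range before decoding.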
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
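+
+
+# Hypothetical usage sketch (not part of the upstream file): how the detection and
+# translation components could be wired into this multilingual pipeline. The custom
+# pipeline id, checkpoints, and prompt below are assumptions, not values taken from
+# this repository.
+#
+#   from diffusers import DiffusionPipeline
+#   from transformers import MBart50TokenizerFast, MBartForConditionalGeneration, pipeline
+#
+#   detector = pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection")
+#   trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
+#   trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
+#
+#   pipe = DiffusionPipeline.from_pretrained(
+#       "runwayml/stable-diffusion-v1-5",
+#       custom_pipeline="multilingual_stable_diffusion",  # assumed custom pipeline id
+#       detection_pipeline=detector,
+#       translation_model=trans_model,
+#       translation_tokenizer=trans_tokenizer,
+#   ).to("cuda")
+#   image = pipe("Una casa junto al mar al atardecer").images[0]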
diff --git a/diffusers/examples/community/one_step_unet.py b/diffusers/examples/community/one_step_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d34bfd83191d63483bc562cb54cc887660cdffa
--- /dev/null
+++ b/diffusers/examples/community/one_step_unet.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python3
+import torch
+
+from diffusers import DiffusionPipeline
+
+
+class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ def __call__(self):
+ image = torch.randn(
+ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
+ )
+ timestep = 1
+
+ model_output = self.unet(image, timestep).sample
+ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
+
+ result = scheduler_output - scheduler_output + torch.ones_like(scheduler_output)
+
+ return result
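+
+
+# Hypothetical usage sketch (not part of the upstream file): this pipeline is essentially a
+# smoke test that runs one UNet forward pass and one scheduler step. The checkpoint name is
+# an assumption; any repo exposing `unet` and `scheduler` components should work.
+#
+#   from diffusers import DiffusionPipeline
+#
+#   pipe = DiffusionPipeline.from_pretrained(
+#       "google/ddpm-cifar10-32", custom_pipeline="one_step_unet"
+#   )
+#   result = pipe()  # a tensor of ones with the UNet's sample shape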
diff --git a/diffusers/examples/community/pipeline_animatediff_controlnet.py b/diffusers/examples/community/pipeline_animatediff_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..bedf002d024c575d2ffb3aa8839b9b63559d2cf4
--- /dev/null
+++ b/diffusers/examples/community/pipeline_animatediff_controlnet.py
@@ -0,0 +1,1129 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
+ >>> from diffusers.pipelines import DiffusionPipeline
+ >>> from diffusers.schedulers import DPMSolverMultistepScheduler
+ >>> from PIL import Image
+
+ >>> motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
+ >>> adapter = MotionAdapter.from_pretrained(motion_id)
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
+ >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+
+ >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... model_id,
+ ... motion_adapter=adapter,
+ ... controlnet=controlnet,
+ ... vae=vae,
+ ... custom_pipeline="pipeline_animatediff_controlnet",
+ ... ).to(device="cuda", dtype=torch.float16)
+ >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
+ ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
+ ... )
+ >>> pipe.enable_vae_slicing()
+
+ >>> conditioning_frames = []
+ >>> for i in range(1, 16 + 1):
+ ... conditioning_frames.append(Image.open(f"frame_{i}.png"))
+
+ >>> prompt = "astronaut in space, dancing"
+ >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
+ >>> result = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=512,
+ ... height=768,
+ ... conditioning_frames=conditioning_frames,
+ ... num_inference_steps=12,
+ ... )
+
+ >>> from diffusers.utils import export_to_gif
+ >>> export_to_gif(result.frames[0], "result.gif")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
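+ # video arrives as (batch, channels, frames, height, width); each clip is permuted to
+ # (frames, channels, height, width) and post-processed frame by frame by the image processor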
+ batch_size, channels, num_frames, height, width = video.shape
+ outputs = []
+ for batch_idx in range(batch_size):
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+ batch_output = processor.postprocess(batch_vid, output_type)
+
+ outputs.append(batch_output)
+
+ if output_type == "np":
+ outputs = np.stack(outputs)
+
+ elif output_type == "pt":
+ outputs = torch.stack(outputs)
+
+ elif not output_type == "pil":
+ raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
+ return outputs
+
+
+class AnimateDiffControlNetPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+):
+ r"""
+ Pipeline for text-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded video latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["feature_extractor", "image_encoder"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ motion_adapter: MotionAdapter,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ feature_extractor: Optional[CLIPImageProcessor] = None,
+ image_encoder: Optional[CLIPVisionModelWithProjection] = None,
+ ):
+ super().__init__()
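+ # wrap the 2D UNet with the motion adapter so the resulting UNetMotionModel can denoise video latents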
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, which contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
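+ # plain ImageProjection layers consume pooled image embeds; other IP-Adapter projection
+ # layers expect the penultimate hidden states of the image encoder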
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ image_embeds = ip_adapter_image_embeds
+ return image_embeds
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
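+ # frames are folded into the batch dimension so the 2D VAE can decode them; the reshape
+ # below restores the (batch, channels, frames, height, width) layout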
+
+ image = self.vae.decode(latents).sample
+ video = (
+ image[None, :]
+ .reshape(
+ (
+ batch_size,
+ num_frames,
+ -1,
+ )
+ + image.shape[2:]
+ )
+ .permute(0, 2, 1, 3, 4)
+ )
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ num_frames,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ image=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(image, list):
+ raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}")
+ if len(image) != num_frames:
+ raise ValueError(f"Expected image to have length {num_frames} but got {len(image)=}")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(image, list) or not isinstance(image[0], list):
+ raise TypeError(f"For multiple controlnets: `image` must be a list of lists but got {type(image)=}")
+ if len(image[0]) != num_frames:
+ raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(image[0])=}")
+ if any(len(img) != len(image[0]) for img in image):
+ raise ValueError("All conditioning frame batches for multicontrolnet must be the same size")
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be of type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_frames: Optional[int] = 16,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
+ conditioning_frames: Optional[List[PipelineImageInput]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated video.
+ num_frames (`int`, *optional*, defaults to 16):
+ The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
+ amounts to 2 seconds of video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of
+ IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)` and should contain
+ the negative image embedding if `do_classifier_free_guidance` is set to `True`.
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ conditioning_frames (`List[PipelineImageInput]`, *optional*):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets
+ are specified, images must be passed as a list such that each element of the list can be correctly
+ batched for input to a single ControlNet.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or
+ `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return an [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ callback_steps=callback_steps,
+ negative_prompt=negative_prompt,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ image=conditioning_frames,
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
+ control_guidance_start=control_guidance_start,
+ control_guidance_end=control_guidance_end,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
+ )
+
+ if isinstance(controlnet, ControlNetModel):
+ conditioning_frames = self.prepare_image(
+ image=conditioning_frames,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_videos_per_prompt * num_frames,
+ num_images_per_prompt=num_videos_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ cond_prepared_frames = []
+ for frame_ in conditioning_frames:
+ prepared_frame = self.prepare_image(
+ image=frame_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_videos_per_prompt * num_frames,
+ num_images_per_prompt=num_videos_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ cond_prepared_frames.append(prepared_frame)
+ conditioning_frames = cond_prepared_frames
+ else:
+ assert False
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+ # 7.1 Create tensor stating which controlnets to keep
+ controlnet_keep = []
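+ # each entry is 1.0 while the current step falls inside the [control_guidance_start,
+ # control_guidance_end] window of its ControlNet and 0.0 otherwise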
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
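+ # The ControlNet operates on 2D inputs, so fold the frame axis into the batch axis:
+ # (batch, channels, frames, height, width) -> (batch * frames, channels, height, width).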
+ control_model_input = torch.transpose(control_model_input, 1, 2)
+ control_model_input = control_model_input.reshape(
+ (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
+ )
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=conditioning_frames,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post processing
+ if output_type == "latent":
+ video = latents
+ else:
+ video_tensor = self.decode_latents(latents)
+ video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+
+ # 10. Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
diff --git a/diffusers/examples/community/pipeline_animatediff_img2video.py b/diffusers/examples/community/pipeline_animatediff_img2video.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a578d4b8ef6d251c327f2e2d3cd9d1b2845e6a3
--- /dev/null
+++ b/diffusers/examples/community/pipeline_animatediff_img2video.py
@@ -0,0 +1,984 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Note:
+# This pipeline relies on a "hack" discovered by the community that allows
+# the generation of videos given an input image with AnimateDiff. It works
+# by creating a copy of the image `num_frames` times and progressively adding
+# more noise to the image based on the strength and latent interpolation method.
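+#
+# Concretely (see `prepare_latents` below): the input image is encoded once with the
+# VAE, and each frame's initial latent is an interpolation (lerp or slerp) between
+# fresh Gaussian noise and the image latents, using the factor
+# `index / num_frames * (1 - strength)` for frame `index`.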
+
+import inspect
+from types import FunctionType
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.unet_motion_model import MotionAdapter
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
+ >>> from diffusers.utils import export_to_gif, load_image
+
+ >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+ >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+ >>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
+ >>> pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)
+
+ >>> image = load_image("snail.png")
+ >>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp")
+ >>> frames = output.frames[0]
+ >>> export_to_gif(frames, "animation.gif")
+ ```
+"""
+
+
+def lerp(
+ v0: torch.Tensor,
+ v1: torch.Tensor,
+ t: Union[float, torch.Tensor],
+) -> torch.Tensor:
+ r"""
+ Linear Interpolation between two tensors.
+
+ Args:
+ v0 (`torch.Tensor`): First tensor.
+ v1 (`torch.Tensor`): Second tensor.
+ t: (`float` or `torch.Tensor`): Interpolation factor.
+ """
+ t_is_float = False
+ input_device = v0.device
+ v0 = v0.cpu().numpy()
+ v1 = v1.cpu().numpy()
+
+ if isinstance(t, torch.Tensor):
+ t = t.cpu().numpy()
+ else:
+ t_is_float = True
+ t = np.array([t], dtype=v0.dtype)
+
+ t = t[..., None]
+ v0 = v0[None, ...]
+ v1 = v1[None, ...]
+ v2 = (1 - t) * v0 + t * v1
+
+ if t_is_float and v0.ndim > 1:
+ assert v2.shape[0] == 1
+ v2 = np.squeeze(v2, axis=0)
+
+ v2 = torch.from_numpy(v2).to(input_device)
+ return v2
+
+
+def slerp(
+ v0: torch.Tensor,
+ v1: torch.Tensor,
+ t: Union[float, torch.Tensor],
+ DOT_THRESHOLD: float = 0.9995,
+) -> torch.Tensor:
+ r"""
+ Spherical Linear Interpolation between two tensors.
+
+ Args:
+ v0 (`torch.Tensor`): First tensor.
+ v1 (`torch.Tensor`): Second tensor.
+ t: (`float` or `torch.Tensor`): Interpolation factor.
+ DOT_THRESHOLD (`float`):
+ Dot product threshold above which linear interpolation is used instead,
+ because the input tensors are close to parallel.
+ """
+ t_is_float = False
+ input_device = v0.device
+ v0 = v0.cpu().numpy()
+ v1 = v1.cpu().numpy()
+
+ if isinstance(t, torch.Tensor):
+ t = t.cpu().numpy()
+ else:
+ t_is_float = True
+ t = np.array([t], dtype=v0.dtype)
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+
+ if np.abs(dot) > DOT_THRESHOLD:
+ # v0 and v1 are close to parallel, so use linear interpolation instead.
+ # v0, v1 and t are already numpy arrays here, so interpolate directly
+ # rather than calling `lerp`, which expects torch tensors.
+ v2 = (1 - t[..., None]) * v0[None, ...] + t[..., None] * v1[None, ...]
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ s0 = s0[..., None]
+ s1 = s1[..., None]
+ v0 = v0[None, ...]
+ v1 = v1[None, ...]
+ v2 = s0 * v0 + s1 * v1
+
+ if t_is_float and v0.ndim > 1:
+ assert v2.shape[0] == 1
+ v2 = np.squeeze(v2, axis=0)
+
+ v2 = torch.from_numpy(v2).to(input_device)
+ return v2
+
+
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
+ batch_size, channels, num_frames, height, width = video.shape
+ outputs = []
+ for batch_idx in range(batch_size):
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+ batch_output = processor.postprocess(batch_vid, output_type)
+
+ outputs.append(batch_output)
+
+ if output_type == "np":
+ outputs = np.stack(outputs)
+
+ elif output_type == "pt":
+ outputs = torch.stack(outputs)
+
+ elif not output_type == "pil":
+ raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
+ return outputs
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class AnimateDiffImgToVideoPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+):
+ r"""
+ Pipeline for image-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["feature_extractor", "image_encoder"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ motion_adapter: MotionAdapter,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ ):
+ super().__init__()
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, which contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if self.do_classifier_free_guidance:
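+ # negative (unconditional) image embeds are placed first, matching the
+ # [negative, positive] ordering used for the prompt embeds under CFG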
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ image_embeds = ip_adapter_image_embeds
+ return image_embeds
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
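+ # decode all frames as a single batch of images:
+ # (batch, channels, frames, height, width) -> (batch * frames, channels, height, width)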
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ image = self.vae.decode(latents).sample
+ video = (
+ image[None, :]
+ .reshape(
+ (
+ batch_size,
+ num_frames,
+ -1,
+ )
+ + image.shape[2:]
+ )
+ .permute(0, 2, 1, 3, 4)
+ )
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ latent_interpolation_method=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if latent_interpolation_method is not None:
+ if latent_interpolation_method not in ["lerp", "slerp"] and not isinstance(
+ latent_interpolation_method, FunctionType
+ ):
+ raise ValueError(
+ "`latent_interpolation_method` must be one of `lerp`, `slerp` or a Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]"
+ )
+
+ def prepare_latents(
+ self,
+ image,
+ strength,
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ latent_interpolation_method="slerp",
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+
+ if latents is None:
+ image = image.to(device=device, dtype=dtype)
+
+ if image.shape[1] == 4:
+ latents = image
+ else:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ if len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ init_latents = self.vae.config.scaling_factor * init_latents
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = latents * self.scheduler.init_noise_sigma
+
+ if latent_interpolation_method == "lerp":
+
+ def latent_cls(v0, v1, index):
+ return lerp(v0, v1, index / num_frames * (1 - strength))
+ elif latent_interpolation_method == "slerp":
+
+ def latent_cls(v0, v1, index):
+ return slerp(v0, v1, index / num_frames * (1 - strength))
+ else:
+ latent_cls = latent_interpolation_method
+
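+ # Blend each frame's noise latent toward the image latents; the interpolation factor
+ # index / num_frames * (1 - strength) grows with the frame index (e.g. for
+ # num_frames=16 and strength=0.8 it ranges from 0.0 up to ~0.19).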
+ for i in range(num_frames):
+ latents[:, :, i, :, :] = latent_cls(latents[:, :, i, :, :], init_latents, i)
+ else:
+ if shape != latents.shape:
+ # [B, C, F, H, W]
+ raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
+ latents = latents.to(device, dtype=dtype)
+
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: int = 16,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 7.5,
+ strength: float = 0.8,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ latent_interpolation_method: Union[str, Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]] = "slerp",
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated video.
+ num_frames (`int`, *optional*, defaults to 16):
+ The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
+ amounts to 2 seconds of video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
+ expense of slower inference.
+ strength (`float`, *optional*, defaults to 0.8):
+ Higher strength leads to more differences between original image and generated video.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. Should be a list with the same length as the number of
+ IP-Adapters, where each element is a tensor of shape `(batch_size, num_images, emb_dim)` and contains
+ the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided,
+ embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or
+ `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`AnimateDiffImgToVideoPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ latent_interpolation_method (`str` or `Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]]`, *optional*):
+ Must be one of "lerp", "slerp" or a callable that takes in a random noisy latent, image latent and a frame index
+ as input and returns an initial latent for sampling.
+ Examples:
+
+ Returns:
+ [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_steps=callback_steps,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ latent_interpolation_method=latent_interpolation_method,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
+ )
+
+ # 4. Preprocess image
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ # 5. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ image=image,
+ strength=strength,
+ batch_size=batch_size * num_videos_per_prompt,
+ num_channels_latents=num_channels_latents,
+ num_frames=num_frames,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ latent_interpolation_method=latent_interpolation_method,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ ).sample
+
+ # perform guidance
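+ # classifier-free guidance: noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)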
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type == "latent":
+ return AnimateDiffPipelineOutput(frames=latents)
+
+ # 10. Post-processing
+ video_tensor = self.decode_latents(latents)
+ video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+
+ # 11. Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
diff --git a/diffusers/examples/community/pipeline_animatediff_ipex.py b/diffusers/examples/community/pipeline_animatediff_ipex.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc65e76bc43b03f7f8315a924e1a7137c5487cfd
--- /dev/null
+++ b/diffusers/examples/community/pipeline_animatediff_ipex.py
@@ -0,0 +1,1002 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import intel_extension_for_pytorch as ipex
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
+from diffusers.pipelines.free_init_utils import FreeInitMixin
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import MotionAdapter, AnimateDiffPipelineIpex, EulerDiscreteScheduler
+ >>> from diffusers.utils import export_to_gif
+ >>> from safetensors.torch import load_file
+ >>> from huggingface_hub import hf_hub_download
+
+ >>> device = "cpu"
+ >>> dtype = torch.float32
+
+ >>> # ByteDance/AnimateDiff-Lightning, a distilled version of AnimateDiff SD1.5 v2,
+ >>> # a lightning-fast text-to-video generation model which can generate videos
+ >>> # more than ten times faster than the original AnimateDiff.
+ >>> step = 8 # Options: [1,2,4,8]
+ >>> repo = "ByteDance/AnimateDiff-Lightning"
+ >>> ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
+ >>> base = "emilianJR/epiCRealism" # Choose to your favorite base model.
+
+ >>> adapter = MotionAdapter().to(device, dtype)
+ >>> adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+
+ >>> pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+ >>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+
+ >>> # For Float32
+ >>> pipe.prepare_for_ipex(torch.float32, prompt = "A girl smiling")
+ >>> # For BFloat16
+ >>> pipe.prepare_for_ipex(torch.bfloat16, prompt = "A girl smiling")
+
+ >>> # For Float32
+ >>> output = pipe(prompt = "A girl smiling", guidance_scale=1.0, num_inference_steps = step)
+ >>> # For BFloat16
+ >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ >>> output = pipe(prompt = "A girl smiling", guidance_scale=1.0, num_inference_steps = step)
+
+ >>> frames = output.frames[0]
+ >>> export_to_gif(frames, "animation.gif")
+ ```
+"""
+
+
+class AnimateDiffPipelineIpex(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+ LoraLoaderMixin,
+ FreeInitMixin,
+):
+ r"""
+ Pipeline for text-to-video generation.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer (`CLIPTokenizer`):
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+ motion_adapter ([`MotionAdapter`]):
+ A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: Union[UNet2DConditionModel, UNetMotionModel],
+ motion_adapter: MotionAdapter,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ ):
+ super().__init__()
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ motion_adapter=motion_adapter,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, which contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
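+    # Informal note on `encode_image`: with `output_hidden_states=True` the penultimate hidden
+    # states of the CLIP vision model are returned (used for the non-`ImageProjection` IP-Adapter
+    # projections), with an all-zero image providing the unconditional branch; otherwise the
+    # pooled `image_embeds` are returned and the unconditional branch is simply a zero tensor.
+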
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
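+    # Note on the structure returned by `prepare_ip_adapter_image_embeds` (informal summary of
+    # the code above): the result is a list with one tensor per IP-Adapter; when classifier-free
+    # guidance is enabled, each tensor already has the negative (zero-image) embeddings
+    # concatenated in front of the positive ones, so the caller can pass the list along directly
+    # with the duplicated prompt embeddings.
+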
+    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ batch_size, channels, num_frames, height, width = latents.shape
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+
+ image = self.vae.decode(latents).sample
+ video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
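+    # Shape walk-through for `decode_latents` (a comment sketch assuming the common setup of
+    # 4 latent channels and `vae_scale_factor == 8`):
+    #
+    #     latents: (batch, 4, num_frames, h / 8, w / 8)
+    #     -> reshaped to (batch * num_frames, 4, h / 8, w / 8) for the 2D VAE decoder
+    #     -> decoded frames: (batch * num_frames, 3, h, w)
+    #     -> returned video: (batch, 3, num_frames, h, w), cast to float32
+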
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
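+    # For example (a sketch, not exhaustive): with `DDIMScheduler`, whose `step()` accepts both
+    # `eta` and `generator`, the dict above becomes {"eta": eta, "generator": generator}; for a
+    # scheduler whose `step()` accepts neither, it stays empty and the extra kwargs are simply
+    # not forwarded.
+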
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
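+    # Latent shape example (comment sketch assuming 4 latent channels and `vae_scale_factor == 8`):
+    # for batch_size=1, num_frames=16 and a 512x512 output, `prepare_latents` returns noise of
+    # shape (1, 4, 16, 64, 64), already scaled by `scheduler.init_noise_sigma`.
+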
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_frames: Optional[int] = 16,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated video.
+ num_frames (`int`, *optional*, defaults to 16):
+                The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
+                amounts to 2 seconds of video.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to higher quality videos at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
+ `(batch_size, num_channel, num_frames, height, width)`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the
+                number of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+ for free_init_iter in range(num_free_init_iters):
+ if self.free_init_enabled:
+ latents, timesteps = self._apply_free_init(
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+ )
+
+ self._num_timesteps = len(timesteps)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 8. Denoising loop
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+                    # NOTE: `cross_attention_kwargs` and `added_cond_kwargs` are kept out of this call so
+                    # that it stays compatible with the jit-traced UNet prepared in `prepare_for_ipex`,
+                    # whose traced forward only accepts these inputs and returns a plain dict
+                    # (hence the `["sample"]` lookup instead of `.sample`).
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                    )["sample"]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # 9. Post processing
+ if output_type == "latent":
+ video = latents
+ else:
+ video_tensor = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
+
+ # 10. Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
+
+ @torch.no_grad()
+ def prepare_for_ipex(
+ self,
+ dtype=torch.float32,
+ prompt: Union[str, List[str]] = None,
+ num_frames: Optional[int] = 16,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_videos_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+ for free_init_iter in range(num_free_init_iters):
+ if self.free_init_enabled:
+ latents, timesteps = self._apply_free_init(
+ latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+ )
+
+ self._num_timesteps = len(timesteps)
+
+ dummy = timesteps[0]
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy)
+
+ self.unet = self.unet.to(memory_format=torch.channels_last)
+ self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last)
+ self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last)
+
+ unet_input_example = {
+ "sample": latent_model_input,
+ "timestep": dummy,
+ "encoder_hidden_states": prompt_embeds,
+ }
+
+ fake_latents = 1 / self.vae.config.scaling_factor * latents
+ batch_size, channels, num_frames, height, width = fake_latents.shape
+ fake_latents = fake_latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+ vae_decoder_input_example = fake_latents
+
+ # optimize with ipex
+ if dtype == torch.bfloat16:
+ self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True)
+ self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)
+ self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
+ elif dtype == torch.float32:
+ self.unet = ipex.optimize(
+ self.unet.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ # sample_input=unet_input_example,
+ level="O1",
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ self.vae.decoder = ipex.optimize(
+ self.vae.decoder.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ level="O1",
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ self.text_encoder = ipex.optimize(
+ self.text_encoder.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ level="O1",
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ else:
+            raise ValueError("The value of `dtype` should be `torch.bfloat16` or `torch.float32`.")
+
+ # trace unet model to get better performance on IPEX
+ with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+ unet_trace_model = torch.jit.trace(
+ self.unet, example_kwarg_inputs=unet_input_example, check_trace=False, strict=False
+ )
+ unet_trace_model = torch.jit.freeze(unet_trace_model)
+ self.unet.forward = unet_trace_model.forward
+
+ # trace vae.decoder model to get better performance on IPEX
+ with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+ vae_decoder_trace_model = torch.jit.trace(
+ self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False
+ )
+ vae_decoder_trace_model = torch.jit.freeze(vae_decoder_trace_model)
+ self.vae.decoder.forward = vae_decoder_trace_model.forward
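+
+        # Rough usage sketch for the IPEX path above (a comment-only example; the pipeline class
+        # name and checkpoint are placeholders, not taken from this file):
+        #
+        #     pipe = MyAnimateDiffPipeline.from_pretrained("<checkpoint>", torch_dtype=torch.float32)
+        #     pipe.prepare_for_ipex(dtype=torch.bfloat16, prompt="a prompt", height=512, width=512)
+        #     with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+        #         frames = pipe("a prompt", height=512, width=512).frames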
diff --git a/diffusers/examples/community/pipeline_demofusion_sdxl.py b/diffusers/examples/community/pipeline_demofusion_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..f83d1b4014200681d8e96d920b588c8fc228f9e3
--- /dev/null
+++ b/diffusers/examples/community/pipeline_demofusion_sdxl.py
@@ -0,0 +1,1386 @@
+import inspect
+import os
+import random
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn.functional as F
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ is_invisible_watermark_available,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import (
+ StableDiffusionXLWatermarker,
+ )
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipeline
+
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
+ x_coord = torch.arange(kernel_size)
+ gaussian_1d = torch.exp(-((x_coord - (kernel_size - 1) / 2) ** 2) / (2 * sigma**2))
+ gaussian_1d = gaussian_1d / gaussian_1d.sum()
+ gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
+ kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
+
+ return kernel
+
+
+def gaussian_filter(latents, kernel_size=3, sigma=1.0):
+ channels = latents.shape[1]
+ kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size // 2, groups=channels)
+
+ return blurred_latents
+
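+# A quick sanity sketch for the two helpers above (comment only; shapes are illustrative):
+#
+#     latents = torch.randn(1, 4, 128, 128)
+#     smoothed = gaussian_filter(latents, kernel_size=3, sigma=1.0)
+#     assert smoothed.shape == latents.shape  # depthwise conv with padding keeps the shape
+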
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
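+# Typical call site for `rescale_noise_cfg` (a sketch; `noise_pred_uncond` / `noise_pred_text`
+# are the two halves of a classifier-free-guidance batch):
+#
+#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+#     if guidance_rescale > 0.0:
+#         noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+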
+
+class DemoFusionSDXLPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ FromSingleFileMixin,
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+            `stabilityai/stable-diffusion-xl-base-1.0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = self.unet.config.sample_size
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+                # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+                # We are only interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
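+    # Shape note for the SDXL-style `encode_prompt` above (informal comment; 77 and 768 / 1280 are
+    # the usual context length and hidden sizes of the two CLIP text encoders):
+    # `prompt_embeds` / `negative_prompt_embeds`: (batch * num_images_per_prompt, 77, 768 + 1280)
+    # `pooled_prompt_embeds` / `negative_pooled_prompt_embeds`: (batch * num_images_per_prompt, 1280)
+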
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ num_images_per_prompt=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # DemoFusion specific checks
+ if max(height, width) % 1024 != 0:
+ raise ValueError(
+                f"The larger of `height` and `width` has to be divisible by 1024, but they are {height} and {width}."
+ )
+
+ if num_images_per_prompt != 1:
+ warnings.warn("num_images_per_prompt != 1 is not supported by DemoFusion and will be ignored.")
+ num_images_per_prompt = 1
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
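+    # Example of the additional conditioning vector (a sketch using the stock SDXL base config,
+    # i.e. addition_time_embed_dim=256 and text_encoder_2 projection_dim=1280):
+    # original_size=(1024, 1024), crops_coords_top_left=(0, 0), target_size=(1024, 1024)
+    # -> add_time_ids = [[1024, 1024, 0, 0, 1024, 1024]] and the expected embedding dimension is
+    #    256 * 6 + 1280 = 2816, matching `unet.add_embedding.linear_1.in_features`.
+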
+ def get_views(self, height, width, window_size=128, stride=64, random_jitter=False):
+ height //= self.vae_scale_factor
+ width //= self.vae_scale_factor
+ num_blocks_height = int((height - window_size) / stride - 1e-6) + 2 if height > window_size else 1
+ num_blocks_width = int((width - window_size) / stride - 1e-6) + 2 if width > window_size else 1
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
+ views = []
+ for i in range(total_num_blocks):
+ h_start = int((i // num_blocks_width) * stride)
+ h_end = h_start + window_size
+ w_start = int((i % num_blocks_width) * stride)
+ w_end = w_start + window_size
+
+ if h_end > height:
+ h_start = int(h_start + height - h_end)
+ h_end = int(height)
+ if w_end > width:
+ w_start = int(w_start + width - w_end)
+ w_end = int(width)
+ if h_start < 0:
+ h_end = int(h_end - h_start)
+ h_start = 0
+ if w_start < 0:
+ w_end = int(w_end - w_start)
+ w_start = 0
+
+ if random_jitter:
+ jitter_range = (window_size - stride) // 4
+ w_jitter = 0
+ h_jitter = 0
+ if (w_start != 0) and (w_end != width):
+ w_jitter = random.randint(-jitter_range, jitter_range)
+ elif (w_start == 0) and (w_end != width):
+ w_jitter = random.randint(-jitter_range, 0)
+ elif (w_start != 0) and (w_end == width):
+ w_jitter = random.randint(0, jitter_range)
+ if (h_start != 0) and (h_end != height):
+ h_jitter = random.randint(-jitter_range, jitter_range)
+ elif (h_start == 0) and (h_end != height):
+ h_jitter = random.randint(-jitter_range, 0)
+ elif (h_start != 0) and (h_end == height):
+ h_jitter = random.randint(0, jitter_range)
+ h_start += h_jitter + jitter_range
+ h_end += h_jitter + jitter_range
+ w_start += w_jitter + jitter_range
+ w_end += w_jitter + jitter_range
+
+ views.append((h_start, h_end, w_start, w_end))
+ return views
+
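+    # Worked example for `get_views` (comment sketch, no jitter, `vae_scale_factor == 8`):
+    # a 2048x2048 image gives a 256x256 latent; with window_size=128 and stride=64 this yields
+    # 3 x 3 = 9 overlapping views such as (0, 128, 0, 128), (0, 128, 64, 192), (0, 128, 128, 256), ...
+    # all expressed in latent-space coordinates.
+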
+ def tiled_decode(self, latents, current_height, current_width):
+ core_size = self.unet.config.sample_size // 4
+ core_stride = core_size
+ pad_size = self.unet.config.sample_size // 4 * 3
+ decoder_view_batch_size = 1
+
+ views = self.get_views(current_height, current_width, stride=core_stride, window_size=core_size)
+ views_batch = [views[i : i + decoder_view_batch_size] for i in range(0, len(views), decoder_view_batch_size)]
+ latents_ = F.pad(latents, (pad_size, pad_size, pad_size, pad_size), "constant", 0)
+ image = torch.zeros(latents.size(0), 3, current_height, current_width).to(latents.device)
+ count = torch.zeros_like(image).to(latents.device)
+ # get the latents corresponding to the current view coordinates
+ with self.progress_bar(total=len(views_batch)) as progress_bar:
+ for j, batch_view in enumerate(views_batch):
+ latents_for_view = torch.cat(
+ [
+ latents_[:, :, h_start : h_end + pad_size * 2, w_start : w_end + pad_size * 2]
+ for h_start, h_end, w_start, w_end in batch_view
+ ]
+ )
+ image_patch = self.vae.decode(latents_for_view / self.vae.config.scaling_factor, return_dict=False)[0]
+ h_start, h_end, w_start, w_end = views[j]
+ h_start, h_end, w_start, w_end = (
+ h_start * self.vae_scale_factor,
+ h_end * self.vae_scale_factor,
+ w_start * self.vae_scale_factor,
+ w_end * self.vae_scale_factor,
+ )
+ p_h_start, p_h_end, p_w_start, p_w_end = (
+ pad_size * self.vae_scale_factor,
+ image_patch.size(2) - pad_size * self.vae_scale_factor,
+ pad_size * self.vae_scale_factor,
+ image_patch.size(3) - pad_size * self.vae_scale_factor,
+ )
+ image[:, :, h_start:h_end, w_start:w_end] += image_patch[:, :, p_h_start:p_h_end, p_w_start:p_w_end]
+ count[:, :, h_start:h_end, w_start:w_end] += 1
+ progress_bar.update()
+ image = image / count
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (AttnProcessor2_0, XFormersAttnProcessor),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = False,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ ################### DemoFusion specific parameters ####################
+ view_batch_size: int = 16,
+ multi_decoder: bool = True,
+ stride: Optional[int] = 64,
+ cosine_scale_1: Optional[float] = 3.0,
+ cosine_scale_2: Optional[float] = 1.0,
+ cosine_scale_3: Optional[float] = 1.0,
+ sigma: Optional[float] = 0.8,
+ show_image: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `False`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the same
+                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ ################### DemoFusion specific parameters ####################
+ view_batch_size (`int`, defaults to 16):
+ The batch size for multiple denoising paths. Typically, a larger batch size can result in higher
+ efficiency but comes with increased GPU memory requirements.
+            multi_decoder (`bool`, defaults to True):
+                Determines whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072,
+                a tiled decoder becomes necessary.
+ stride (`int`, defaults to 64):
+ The stride of moving local patches. A smaller stride is better for alleviating seam issues,
+ but it also introduces additional computational overhead and inference time.
+            cosine_scale_1 (`float`, defaults to 3):
+                Controls the strength of the skip-residual. For specific impacts, please refer to Appendix C
+                in the DemoFusion paper.
+            cosine_scale_2 (`float`, defaults to 1):
+                Controls the strength of dilated sampling. For specific impacts, please refer to Appendix C
+                in the DemoFusion paper.
+            cosine_scale_3 (`float`, defaults to 1):
+                Controls the strength of the Gaussian filter. For specific impacts, please refer to Appendix C
+                in the DemoFusion paper.
+            sigma (`float`, defaults to 0.8):
+                The standard deviation of the Gaussian filter.
+            show_image (`bool`, defaults to False):
+                Determines whether to show intermediate results during generation.
+
+ Examples:
+
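+            A minimal, illustrative usage sketch (assuming `pipe` is an already-loaded instance of this
+            DemoFusion SDXL pipeline on a GPU; the returned list contains one image per phase):
+
+            ```py
+            >>> images = pipe(
+            ...     prompt="a photo of the dolomites, highly detailed",
+            ...     height=2048,
+            ...     width=2048,
+            ...     num_inference_steps=50,
+            ...     view_batch_size=16,
+            ...     stride=64,
+            ...     multi_decoder=True,
+            ... )
+            >>> images[-1].save("demofusion_2048.png")
+            ```
+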
+ Returns:
+ a `list` with the generated images at each phase.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
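+        # DemoFusion upsamples progressively: `scale_num` is the number of phases relative to the base
+        # resolution (e.g. a 3072x3072 request on a 1024 base runs 3 phases), and `aspect_ratio`
+        # preserves non-square targets when computing per-phase sizes.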
+ x1_size = self.default_sample_size * self.vae_scale_factor
+
+ height_scale = height / x1_size
+ width_scale = width / x1_size
+ scale_num = int(max(height_scale, width_scale))
+ aspect_ratio = min(height_scale, width_scale) / max(height_scale, width_scale)
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ num_images_per_prompt,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height // scale_num,
+ width // scale_num,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        # 8.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ output_images = []
+
+ ############################################################### Phase 1 #################################################################
+
+ print("### Phase 1 Denoising ###")
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latents_for_view = latents
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = latents.repeat_interleave(2, dim=0) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ anchor_mean = latents.mean()
+ anchor_std = latents.std()
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ print("### Phase 1 Decoding ###")
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ if show_image:
+ plt.figure(figsize=(10, 10))
+ plt.imshow(image[0])
+ plt.axis("off") # Turn off axis numbers and ticks
+ plt.show()
+ output_images.append(image[0])
+
+ ####################################################### Phase 2+ #####################################################
+
+ for current_scale_num in range(2, scale_num + 1):
+ print("### Phase {} Denoising ###".format(current_scale_num))
+ current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
+ current_width = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
+ if height > width:
+ current_width = int(current_width * aspect_ratio)
+ else:
+ current_height = int(current_height * aspect_ratio)
+
+ latents = F.interpolate(
+ latents,
+ size=(int(current_height / self.vae_scale_factor), int(current_width / self.vae_scale_factor)),
+ mode="bicubic",
+ )
+
+ noise_latents = []
+ noise = torch.randn_like(latents)
+ for timestep in timesteps:
+ noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))
+ noise_latents.append(noise_latent)
+ latents = noise_latents[0]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ count = torch.zeros_like(latents)
+ value = torch.zeros_like(latents)
+ cosine_factor = (
+ 0.5
+ * (
+ 1
+ + torch.cos(
+ torch.pi
+ * (self.scheduler.config.num_train_timesteps - t)
+ / self.scheduler.config.num_train_timesteps
+ )
+ ).cpu()
+ )
+
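+                    # Skip-residual: blend the current (upsampled) latents with the pre-computed noised
+                    # latents for this timestep. `cosine_factor` starts near 1 and decays towards 0 over
+                    # the denoising schedule, so the noise-injected latents dominate early and the
+                    # denoised latents take over later; `cosine_scale_1` sharpens or softens that schedule.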
+ c1 = cosine_factor**cosine_scale_1
+ latents = latents * (1 - c1) + noise_latents[i] * c1
+
+ ############################################# MultiDiffusion #############################################
+
+ views = self.get_views(
+ current_height,
+ current_width,
+ stride=stride,
+ window_size=self.unet.config.sample_size,
+ random_jitter=True,
+ )
+ views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
+
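+                    # Pad the latents so that randomly jittered patch windows near the border stay inside
+                    # the canvas; the same `jitter_range` margin is cropped away again after the views are
+                    # merged below.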
+ jitter_range = (self.unet.config.sample_size - stride) // 4
+ latents_ = F.pad(latents, (jitter_range, jitter_range, jitter_range, jitter_range), "constant", 0)
+
+ count_local = torch.zeros_like(latents_)
+ value_local = torch.zeros_like(latents_)
+
+ for j, batch_view in enumerate(views_batch):
+ vb_size = len(batch_view)
+
+ # get the latents corresponding to the current view coordinates
+ latents_for_view = torch.cat(
+ [
+ latents_[:, :, h_start:h_end, w_start:w_end]
+ for h_start, h_end, w_start, w_end in batch_view
+ ]
+ )
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = latents_for_view
+ latent_model_input = (
+ latent_model_input.repeat_interleave(2, dim=0)
+ if do_classifier_free_guidance
+ else latent_model_input
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
+ add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)
+ add_time_ids_input = []
+ for h_start, h_end, w_start, w_end in batch_view:
+ add_time_ids_ = add_time_ids.clone()
+ add_time_ids_[:, 2] = h_start * self.vae_scale_factor
+ add_time_ids_[:, 3] = w_start * self.vae_scale_factor
+ add_time_ids_input.append(add_time_ids_)
+ add_time_ids_input = torch.cat(add_time_ids_input)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds_input,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ self.scheduler._init_step_index(t)
+ latents_denoised_batch = self.scheduler.step(
+ noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ # extract value from batch
+ for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
+ latents_denoised_batch.chunk(vb_size), batch_view
+ ):
+ value_local[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+ count_local[:, :, h_start:h_end, w_start:w_end] += 1
+
+ value_local = value_local[
+ :,
+ :,
+ jitter_range : jitter_range + current_height // self.vae_scale_factor,
+ jitter_range : jitter_range + current_width // self.vae_scale_factor,
+ ]
+ count_local = count_local[
+ :,
+ :,
+ jitter_range : jitter_range + current_height // self.vae_scale_factor,
+ jitter_range : jitter_range + current_width // self.vae_scale_factor,
+ ]
+
+ c2 = cosine_factor**cosine_scale_2
+
+ value += value_local / count_local * (1 - c2)
+ count += torch.ones_like(value_local) * (1 - c2)
+
+ ############################################# Dilated Sampling #############################################
+
+ views = [[h, w] for h in range(current_scale_num) for w in range(current_scale_num)]
+ views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
+
+ h_pad = (current_scale_num - (latents.size(2) % current_scale_num)) % current_scale_num
+ w_pad = (current_scale_num - (latents.size(3) % current_scale_num)) % current_scale_num
+ latents_ = F.pad(latents, (w_pad, 0, h_pad, 0), "constant", 0)
+
+ count_global = torch.zeros_like(latents_)
+ value_global = torch.zeros_like(latents_)
+
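+                    # Dilated sampling operates on a Gaussian-smoothed copy of the latents; the filter
+                    # strength (sigma * c3) follows the cosine schedule so smoothing is strong early and
+                    # fades out, and the smoothed latents are re-standardized to the original mean/std.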
+ c3 = 0.99 * cosine_factor**cosine_scale_3 + 1e-2
+ std_, mean_ = latents_.std(), latents_.mean()
+ latents_gaussian = gaussian_filter(
+ latents_, kernel_size=(2 * current_scale_num - 1), sigma=sigma * c3
+ )
+ latents_gaussian = (
+ latents_gaussian - latents_gaussian.mean()
+ ) / latents_gaussian.std() * std_ + mean_
+
+ for j, batch_view in enumerate(views_batch):
+ latents_for_view = torch.cat(
+ [latents_[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view]
+ )
+ latents_for_view_gaussian = torch.cat(
+ [latents_gaussian[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view]
+ )
+
+ vb_size = latents_for_view.size(0)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = latents_for_view_gaussian
+ latent_model_input = (
+ latent_model_input.repeat_interleave(2, dim=0)
+ if do_classifier_free_guidance
+ else latent_model_input
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
+ add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)
+ add_time_ids_input = torch.cat([add_time_ids] * vb_size)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds_input,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ self.scheduler._init_step_index(t)
+ latents_denoised_batch = self.scheduler.step(
+ noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ # extract value from batch
+ for latents_view_denoised, (h, w) in zip(latents_denoised_batch.chunk(vb_size), batch_view):
+ value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised
+ count_global[:, :, h::current_scale_num, w::current_scale_num] += 1
+
+ c2 = cosine_factor**cosine_scale_2
+
+ value_global = value_global[:, :, h_pad:, w_pad:]
+
+ value += value_global * c2
+ count += torch.ones_like(value_global) * c2
+
+ ###########################################################
+
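+                    # Fuse the two denoising paths: the MultiDiffusion (local patch) estimates were
+                    # accumulated with weight (1 - c2) and the dilated-sampling (global) estimates with
+                    # weight c2, so the normalized sum below is their cosine-scheduled weighted average.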
+ latents = torch.where(count > 0, value / count, value)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ #########################################################################################################################################
+
+ latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ print("### Phase {} Decoding ###".format(current_scale_num))
+ if multi_decoder:
+ image = self.tiled_decode(latents, current_height, current_width)
+ else:
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ if show_image:
+ plt.figure(figsize=(10, 10))
+ plt.imshow(image[0])
+ plt.axis("off") # Turn off axis numbers and ticks
+ plt.show()
+ output_images.append(image[0])
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ return output_images
+
+ # Override to properly handle the loading and unloading of the additional text encoder.
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+
+ # Remove any existing hooks.
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+ else:
+ raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
+
+ is_model_cpu_offload = False
+ is_sequential_cpu_offload = False
+ recursive = False
+ for _, component in self.components.items():
+ if isinstance(component, torch.nn.Module):
+ if hasattr(component, "_hf_hook"):
+ is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+ is_sequential_cpu_offload = (
+ isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ recursive = is_sequential_cpu_offload
+ remove_hook_from_module(component, recurse=recursive)
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ )
+
+ # Offload back.
+ if is_model_cpu_offload:
+ self.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ self.enable_sequential_cpu_offload()
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ state_dict = {}
+
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+ raise ValueError(
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+ )
+
+ if unet_lora_layers:
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def _remove_text_encoder_monkey_patch(self):
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
diff --git a/diffusers/examples/community/pipeline_fabric.py b/diffusers/examples/community/pipeline_fabric.py
new file mode 100644
index 0000000000000000000000000000000000000000..02fdcd04c1038f8f5f24ea2dfc8fe97f9eb8b8d8
--- /dev/null
+++ b/diffusers/examples/community/pipeline_fabric.py
@@ -0,0 +1,751 @@
+# Copyright 2024 FABRIC authors and the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Union
+
+import torch
+from packaging import version
+from PIL import Image
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.schedulers import EulerAncestralDiscreteScheduler, KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import DiffusionPipeline
+ >>> import torch
+
+ >>> model_id = "dreamlike-art/dreamlike-photoreal-2.0"
+        >>> pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, custom_pipeline="pipeline_fabric")
+ >>> pipe = pipe.to("cuda")
+ >>> prompt = "a giant standing in a fantasy landscape best quality"
+ >>> liked = [] # list of images for positive feedback
+ >>> disliked = [] # list of images for negative feedback
+ >>> image = pipe(prompt, num_images=4, liked=liked, disliked=disliked).images[0]
+ ```
+"""
+
+
+class FabricCrossAttnProcessor:
+ def __init__(self):
+        self.attention_probs = None
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ weights=None,
+ lora_scale=1.0,
+ ):
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if isinstance(attn.processor, LoRAAttnProcessor):
+ query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)
+ else:
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if isinstance(attn.processor, LoRAAttnProcessor):
+ key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
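+        # FABRIC feedback: scale the attention paid to the appended feedback key/value tokens by
+        # `weights` (the original tokens keep weight 1), then renormalize each row so the attention
+        # probabilities still sum to 1.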
+ if weights is not None:
+ if weights.shape[0] != 1:
+ weights = weights.repeat_interleave(attn.heads, dim=0)
+ attention_probs = attention_probs * weights[:, None]
+ attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)
+
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ if isinstance(attn.processor, LoRAAttnProcessor):
+ hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
+ else:
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class FabricPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion and conditioning the results using feedback images.
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`EulerAncestralDiscreteScheduler`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ unet=unet,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def get_unet_hidden_states(self, z_all, t, prompt_embd):
+ cached_hidden_states = []
+ for module in self.unet.modules():
+ if isinstance(module, BasicTransformerBlock):
+
+ def new_forward(self, hidden_states, *args, **kwargs):
+ cached_hidden_states.append(hidden_states.clone().detach().cpu())
+ return self.old_forward(hidden_states, *args, **kwargs)
+
+ module.attn1.old_forward = module.attn1.forward
+ module.attn1.forward = new_forward.__get__(module.attn1)
+
+ # run forward pass to cache hidden states, output can be discarded
+ _ = self.unet(z_all, t, encoder_hidden_states=prompt_embd)
+
+ # restore original forward pass
+ for module in self.unet.modules():
+ if isinstance(module, BasicTransformerBlock):
+ module.attn1.forward = module.attn1.old_forward
+ del module.attn1.old_forward
+
+ return cached_hidden_states
+
+ def unet_forward_with_cached_hidden_states(
+ self,
+ z_all,
+ t,
+ prompt_embd,
+ cached_pos_hiddens: Optional[List[torch.Tensor]] = None,
+ cached_neg_hiddens: Optional[List[torch.Tensor]] = None,
+ pos_weights=(0.8, 0.8),
+ neg_weights=(0.5, 0.5),
+ ):
+ if cached_pos_hiddens is None and cached_neg_hiddens is None:
+ return self.unet(z_all, t, encoder_hidden_states=prompt_embd)
+
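+        # Feedback weights are interpolated linearly across the UNet down blocks, the bottleneck weight
+        # is used for the mid block, and the down-block weights are mirrored for the up blocks.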
+ local_pos_weights = torch.linspace(*pos_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist()
+ local_neg_weights = torch.linspace(*neg_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist()
+ for block, pos_weight, neg_weight in zip(
+ self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks,
+ local_pos_weights + [pos_weights[1]] + local_pos_weights[::-1],
+ local_neg_weights + [neg_weights[1]] + local_neg_weights[::-1],
+ ):
+ for module in block.modules():
+ if isinstance(module, BasicTransformerBlock):
+
+ def new_forward(
+ self,
+ hidden_states,
+ pos_weight=pos_weight,
+ neg_weight=neg_weight,
+ **kwargs,
+ ):
+ cond_hiddens, uncond_hiddens = hidden_states.chunk(2, dim=0)
+ batch_size, d_model = cond_hiddens.shape[:2]
+ device, dtype = hidden_states.device, hidden_states.dtype
+
+ weights = torch.ones(batch_size, d_model, device=device, dtype=dtype)
+                            # out_pos / out_neg are assigned in every branch below, so no extra
+                            # full forward pass is needed here
+
+ if cached_pos_hiddens is not None:
+ cached_pos_hs = cached_pos_hiddens.pop(0).to(hidden_states.device)
+ cond_pos_hs = torch.cat([cond_hiddens, cached_pos_hs], dim=1)
+ pos_weights = weights.clone().repeat(1, 1 + cached_pos_hs.shape[1] // d_model)
+ pos_weights[:, d_model:] = pos_weight
+ attn_with_weights = FabricCrossAttnProcessor()
+ out_pos = attn_with_weights(
+ self,
+ cond_hiddens,
+ encoder_hidden_states=cond_pos_hs,
+ weights=pos_weights,
+ )
+ else:
+ out_pos = self.old_forward(cond_hiddens)
+
+ if cached_neg_hiddens is not None:
+ cached_neg_hs = cached_neg_hiddens.pop(0).to(hidden_states.device)
+ uncond_neg_hs = torch.cat([uncond_hiddens, cached_neg_hs], dim=1)
+ neg_weights = weights.clone().repeat(1, 1 + cached_neg_hs.shape[1] // d_model)
+ neg_weights[:, d_model:] = neg_weight
+ attn_with_weights = FabricCrossAttnProcessor()
+ out_neg = attn_with_weights(
+ self,
+ uncond_hiddens,
+ encoder_hidden_states=uncond_neg_hs,
+ weights=neg_weights,
+ )
+ else:
+ out_neg = self.old_forward(uncond_hiddens)
+
+ out = torch.cat([out_pos, out_neg], dim=0)
+ return out
+
+ module.attn1.old_forward = module.attn1.forward
+ module.attn1.forward = new_forward.__get__(module.attn1)
+
+ out = self.unet(z_all, t, encoder_hidden_states=prompt_embd)
+
+ # restore original forward pass
+ for module in self.unet.modules():
+ if isinstance(module, BasicTransformerBlock):
+ module.attn1.forward = module.attn1.old_forward
+ del module.attn1.old_forward
+
+ return out
+
+    def preprocess_feedback_images(self, images, vae, dim, device, dtype, generator) -> torch.Tensor:
+ images_t = [self.image_to_tensor(img, dim, dtype) for img in images]
+ images_t = torch.stack(images_t).to(device)
+ latents = vae.config.scaling_factor * vae.encode(images_t).latent_dist.sample(generator)
+
+ return torch.cat([latents], dim=0)
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt=None,
+ liked=None,
+ disliked=None,
+ height=None,
+ width=None,
+ ):
+ if prompt is None:
+            raise ValueError("Provide `prompt`. Cannot leave `prompt` undefined.")
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if liked is not None and not isinstance(liked, list):
+ raise ValueError(f"`liked` has to be of type `list` but is {type(liked)}")
+
+ if disliked is not None and not isinstance(disliked, list):
+ raise ValueError(f"`disliked` has to be of type `list` but is {type(disliked)}")
+
+ if height is not None and not isinstance(height, int):
+ raise ValueError(f"`height` has to be of type `int` but is {type(height)}")
+
+ if width is not None and not isinstance(width, int):
+ raise ValueError(f"`width` has to be of type `int` but is {type(width)}")
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = "",
+ negative_prompt: Optional[Union[str, List[str]]] = "lowres, bad anatomy, bad hands, cropped, worst quality",
+ liked: Optional[Union[List[str], List[Image.Image]]] = [],
+ disliked: Optional[Union[List[str], List[Image.Image]]] = [],
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ height: int = 512,
+ width: int = 512,
+ return_dict: bool = True,
+ num_images: int = 4,
+ guidance_scale: float = 7.0,
+ num_inference_steps: int = 20,
+ output_type: Optional[str] = "pil",
+ feedback_start_ratio: float = 0.33,
+ feedback_end_ratio: float = 0.66,
+ min_weight: float = 0.05,
+ max_weight: float = 0.8,
+ neg_scale: float = 0.5,
+ pos_bottleneck_scale: float = 1.0,
+ neg_bottleneck_scale: float = 1.0,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation. Generate a trajectory of images with binary feedback. The
+ feedback can be given as a list of liked and disliked images.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ liked (`List[Image.Image]` or `List[str]`, *optional*):
+ Encourages images with liked features.
+ disliked (`List[Image.Image]` or `List[str]`, *optional*):
+ Discourages images with disliked features.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+ height (`int`, *optional*, defaults to 512):
+ Height of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ Width of the generated image.
+ num_images (`int`, *optional*, defaults to 4):
+ The number of images to generate per prompt.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ feedback_start_ratio (`float`, *optional*, defaults to `.33`):
+ Start point for providing feedback (between 0 and 1).
+ feedback_end_ratio (`float`, *optional*, defaults to `.66`):
+ End point for providing feedback (between 0 and 1).
+ min_weight (`float`, *optional*, defaults to `.05`):
+ Minimum weight for feedback.
+            max_weight (`float`, *optional*, defaults to `.8`):
+ Maximum weight for feedback.
+ neg_scale (`float`, *optional*, defaults to `.5`):
+ Scale factor for negative feedback.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.fabric.FabricPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+
+ """
+
+ self.check_inputs(prompt, negative_prompt, liked, disliked)
+
+ device = self._execution_device
+ dtype = self.unet.dtype
+
+ if isinstance(prompt, str) and prompt is not None:
+ batch_size = 1
+ elif isinstance(prompt, list) and prompt is not None:
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if isinstance(negative_prompt, list):
+            # a list of negative prompts must match the prompt batch size
+            assert len(negative_prompt) == batch_size
+        elif not isinstance(negative_prompt, str):
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ shape = (
+ batch_size * num_images,
+ self.unet.config.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ latent_noise = randn_tensor(
+ shape,
+ device=device,
+ dtype=dtype,
+ generator=generator,
+ )
+
+ positive_latents = (
+ self.preprocess_feedback_images(liked, self.vae, (height, width), device, dtype, generator)
+ if liked and len(liked) > 0
+ else torch.tensor(
+ [],
+ device=device,
+ dtype=dtype,
+ )
+ )
+ negative_latents = (
+ self.preprocess_feedback_images(disliked, self.vae, (height, width), device, dtype, generator)
+ if disliked and len(disliked) > 0
+ else torch.tensor(
+ [],
+ device=device,
+ dtype=dtype,
+ )
+ )
+
+ do_classifier_free_guidance = guidance_scale > 0.1
+
+ (prompt_neg_embs, prompt_pos_embs) = self._encode_prompt(
+ prompt,
+ device,
+ num_images,
+ do_classifier_free_guidance,
+ negative_prompt,
+ ).split([num_images * batch_size, num_images * batch_size])
+
+ batched_prompt_embd = torch.cat([prompt_pos_embs, prompt_neg_embs], dim=0)
+
+ null_tokens = self.tokenizer(
+ [""],
+ return_tensors="pt",
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = null_tokens.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ null_prompt_emb = self.text_encoder(
+ input_ids=null_tokens.input_ids.to(device),
+ attention_mask=attention_mask,
+ ).last_hidden_state
+
+ null_prompt_emb = null_prompt_emb.to(device=device, dtype=dtype)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ latent_noise = latent_noise * self.scheduler.init_noise_sigma
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ ref_start_idx = round(len(timesteps) * feedback_start_ratio)
+ ref_end_idx = round(len(timesteps) * feedback_end_ratio)
+
+ with self.progress_bar(total=num_inference_steps) as pbar:
+ for i, t in enumerate(timesteps):
+ sigma = self.scheduler.sigma_t[t] if hasattr(self.scheduler, "sigma_t") else 0
+ if hasattr(self.scheduler, "sigmas"):
+ sigma = self.scheduler.sigmas[i]
+
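+                # Convert the scheduler's sigma into the signal fraction alpha_hat = 1 / (sigma^2 + 1);
+                # it is used below to noise the feedback latents to the current timestep for
+                # Euler-ancestral-style schedulers.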
+ alpha_hat = 1 / (sigma**2 + 1)
+
+ z_single = self.scheduler.scale_model_input(latent_noise, t)
+ z_all = torch.cat([z_single] * 2, dim=0)
+ z_ref = torch.cat([positive_latents, negative_latents], dim=0)
+
+ if i >= ref_start_idx and i <= ref_end_idx:
+ weight_factor = max_weight
+ else:
+ weight_factor = min_weight
+
+ pos_ws = (weight_factor, weight_factor * pos_bottleneck_scale)
+ neg_ws = (weight_factor * neg_scale, weight_factor * neg_scale * neg_bottleneck_scale)
+
+ if z_ref.size(0) > 0 and weight_factor > 0:
+ noise = torch.randn_like(z_ref)
+ if isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
+ z_ref_noised = (alpha_hat**0.5 * z_ref + (1 - alpha_hat) ** 0.5 * noise).type(dtype)
+ else:
+ z_ref_noised = self.scheduler.add_noise(z_ref, noise, t)
+
+ ref_prompt_embd = torch.cat(
+ [null_prompt_emb] * (len(positive_latents) + len(negative_latents)), dim=0
+ )
+ cached_hidden_states = self.get_unet_hidden_states(z_ref_noised, t, ref_prompt_embd)
+
+ n_pos, n_neg = positive_latents.shape[0], negative_latents.shape[0]
+ cached_pos_hs, cached_neg_hs = [], []
+ for hs in cached_hidden_states:
+ cached_pos, cached_neg = hs.split([n_pos, n_neg], dim=0)
+ cached_pos = cached_pos.view(1, -1, *cached_pos.shape[2:]).expand(num_images, -1, -1)
+ cached_neg = cached_neg.view(1, -1, *cached_neg.shape[2:]).expand(num_images, -1, -1)
+ cached_pos_hs.append(cached_pos)
+ cached_neg_hs.append(cached_neg)
+
+ if n_pos == 0:
+ cached_pos_hs = None
+ if n_neg == 0:
+ cached_neg_hs = None
+ else:
+ cached_pos_hs, cached_neg_hs = None, None
+ unet_out = self.unet_forward_with_cached_hidden_states(
+ z_all,
+ t,
+ prompt_embd=batched_prompt_embd,
+ cached_pos_hiddens=cached_pos_hs,
+ cached_neg_hiddens=cached_neg_hs,
+ pos_weights=pos_ws,
+ neg_weights=neg_ws,
+ )[0]
+
+ noise_cond, noise_uncond = unet_out.chunk(2)
+ guidance = noise_cond - noise_uncond
+ noise_pred = noise_uncond + guidance_scale * guidance
+ latent_noise = self.scheduler.step(noise_pred, t, latent_noise)[0]
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ pbar.update()
+
+ y = self.vae.decode(latent_noise / self.vae.config.scaling_factor, return_dict=False)[0]
+ imgs = self.image_processor.postprocess(
+ y,
+ output_type=output_type,
+ )
+
+ if not return_dict:
+ return imgs
+
+ return StableDiffusionPipelineOutput(imgs, False)
+
+ def image_to_tensor(self, image: Union[str, Image.Image], dim: tuple, dtype):
+ """
+ Convert a PIL image (or a path to an image file) to a torch tensor for further processing.
+ """
+ if isinstance(image, str):
+ image = Image.open(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = self.image_processor.preprocess(image, height=dim[0], width=dim[1])[0]
+ return image.type(dtype)
diff --git a/diffusers/examples/community/pipeline_flux_with_cfg.py b/diffusers/examples/community/pipeline_flux_with_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..06da6da899cd57925ebcf04baf356f0c20ecb95b
--- /dev/null
+++ b/diffusers/examples/community/pipeline_flux_with_cfg.py
@@ -0,0 +1,850 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxPipeline
+
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
+ >>> image.save("flux.png")
+ ```
+"""
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
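+ # Linearly interpolate the timestep-shift parameter mu between base_shift and max_shift as
+ # the image token sequence length grows from base_seq_len to max_seq_len.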
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ The Flux pipeline for text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 64
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ do_true_cfg: bool = False,
+ ):
+ device = device or self._execution_device
+
+ # Set LoRA scale if applicable
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if do_true_cfg and negative_prompt is not None:
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_batch_size = len(negative_prompt)
+
+ if negative_batch_size != batch_size:
+ raise ValueError(
+ f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
+ )
+
+ # Concatenate prompts
+ prompts = prompt + negative_prompt
+ prompts_2 = (
+ prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
+ )
+ else:
+ prompts = prompt
+ prompts_2 = prompt_2
+
+ if prompt_embeds is None:
+ if prompts_2 is None:
+ prompts_2 = prompts
+
+ # Get pooled prompt embeddings from CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompts,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompts_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if do_true_cfg and negative_prompt is not None:
+ # Split embeddings back into positive and negative parts
+ total_batch_size = batch_size * num_images_per_prompt
+ positive_indices = slice(0, total_batch_size)
+ negative_indices = slice(total_batch_size, 2 * total_batch_size)
+
+ positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
+ negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
+
+ positive_prompt_embeds = prompt_embeds[positive_indices]
+ negative_prompt_embeds = prompt_embeds[negative_indices]
+
+ pooled_prompt_embeds = positive_pooled_prompt_embeds
+ prompt_embeds = positive_prompt_embeds
+
+ # Unscale LoRA layers
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
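+ # Text tokens use all-zero position ids: one (0, 0, 0) triple per token of the T5 sequence.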
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ if do_true_cfg and negative_prompt is not None:
+ return (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+ else:
+ return prompt_embeds, pooled_prompt_embeds, text_ids, None, None
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
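+ # Build per-patch positional ids: channel 1 holds the row index and channel 2 the column
+ # index of each 2x2 latent patch, flattened to shape (num_patches, 3).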
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
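+ # Rearrange latents from (B, C, H, W) into a sequence of 2x2 patches with shape
+ # (B, (H/2) * (W/2), C * 4), the packed token layout consumed by the transformer.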
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
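+ # Inverse of _pack_latents: fold the (B, num_patches, C * 4) token sequence back into a
+ # (B, C, H, W) latent image before VAE decoding.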
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
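+ # Convert pixel dimensions to the packed latent grid: 2 * (size // vae_scale_factor) per
+ # side, which also guarantees even dimensions for the 2x2 patch packing below.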
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ do_true_cfg = true_cfg > 1 and negative_prompt is not None
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ do_true_cfg=do_true_cfg,
+ )
+
+ if do_true_cfg:
+ # Concatenate embeddings
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latent_model_input.shape[0])
+ else:
+ guidance = None
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
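+ # With true CFG the batch holds [negative, positive] predictions; recombine them using the
+ # true_cfg scale before the scheduler step.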
+ if do_true_cfg:
+ neg_noise_pred, noise_pred = noise_pred.chunk(2)
+ noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_hunyuandit_differential_img2img.py b/diffusers/examples/community/pipeline_hunyuandit_differential_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ece670e5bde8e822f0411a8e41bc168eb94153b
--- /dev/null
+++ b/diffusers/examples/community/pipeline_hunyuandit_differential_img2img.py
@@ -0,0 +1,1144 @@
+# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import (
+ BertModel,
+ BertTokenizer,
+ CLIPImageProcessor,
+ MT5Tokenizer,
+ T5EncoderModel,
+)
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
+from diffusers.models.embeddings import get_2d_rotary_pos_embed
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+ StableDiffusionSafetyChecker,
+)
+from diffusers.schedulers import DDPMScheduler
+from diffusers.utils import (
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FlowMatchEulerDiscreteScheduler
+ >>> from diffusers.utils import load_image
+ >>> from PIL import Image
+ >>> from torchvision import transforms
+ >>> from pipeline_hunyuandit_differential_img2img import HunyuanDiTDifferentialImg2ImgPipeline
+ >>> pipe = HunyuanDiTDifferentialImg2ImgPipeline.from_pretrained(
+ ... "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
+ ... ).to("cuda")
+ >>> source_image = load_image(
+ ... "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png"
+ ... )
+ >>> map = load_image(
+ ... "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask_2.png"
+ ... )
+ >>> prompt = "a green pear"
+ >>> negative_prompt = "blurry"
+ >>> image = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... image=source_image,
+ ... num_inference_steps=28,
+ ... guidance_scale=4.5,
+ ... strength=1.0,
+ ... map=map,
+ ... ).images[0]
+
+ ```
+"""
+
+STANDARD_RATIO = np.array(
+ [
+ 1.0, # 1:1
+ 4.0 / 3.0, # 4:3
+ 3.0 / 4.0, # 3:4
+ 16.0 / 9.0, # 16:9
+ 9.0 / 16.0, # 9:16
+ ]
+)
+STANDARD_SHAPE = [
+ [(1024, 1024), (1280, 1280)], # 1:1
+ [(1024, 768), (1152, 864), (1280, 960)], # 4:3
+ [(768, 1024), (864, 1152), (960, 1280)], # 3:4
+ [(1280, 768)], # 16:9
+ [(768, 1280)], # 9:16
+]
+STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
+SUPPORTED_SHAPE = [
+ (1024, 1024),
+ (1280, 1280), # 1:1
+ (1024, 768),
+ (1152, 864),
+ (1280, 960), # 4:3
+ (768, 1024),
+ (864, 1152),
+ (960, 1280), # 3:4
+ (1280, 768), # 16:9
+ (768, 1280), # 9:16
+]
+
+
+def map_to_standard_shapes(target_width, target_height):
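+ # Snap an arbitrary (width, height) request to the closest supported HunyuanDiT resolution,
+ # first by aspect ratio and then by total pixel area.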
+ target_ratio = target_width / target_height
+ closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
+ closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
+ width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
+ return width, height
+
+
+def get_resize_crop_region_for_grid(src, tgt_size):
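+ # Scale (h, w) so its longer side equals tgt_size, then return the centered crop region
+ # ((top, left), (bottom, right)) on that grid, used when computing the 2D rotary position embeddings.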
+ th = tw = tgt_size
+ h, w = src
+
+ r = h / w
+
+ # resize
+ if r > 1:
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ sample_mode: str = "sample",
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Differential Pipeline for English/Chinese-to-image generation using HunyuanDiT.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP
+ fine-tuned by the HunyuanDiT team.
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use
+ `sdxl-vae-fp16-fix`.
+ text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ HunyuanDiT uses a fine-tuned bilingual CLIP text encoder.
+ tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
+ A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
+ transformer ([`HunyuanDiT2DModel`]):
+ The HunyuanDiT model designed by Tencent Hunyuan.
+ text_encoder_2 (`T5EncoderModel`):
+ The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
+ tokenizer_2 (`MT5Tokenizer`):
+ The tokenizer for the mT5 embedder.
+ scheduler ([`DDPMScheduler`]):
+ A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = [
+ "safety_checker",
+ "feature_extractor",
+ "text_encoder_2",
+ "tokenizer_2",
+ "text_encoder",
+ "tokenizer",
+ ]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "prompt_embeds_2",
+ "negative_prompt_embeds_2",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: BertModel,
+ tokenizer: BertTokenizer,
+ transformer: HunyuanDiT2DModel,
+ scheduler: DDPMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ text_encoder_2=T5EncoderModel,
+ tokenizer_2=MT5Tokenizer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ text_encoder_2=text_encoder_2,
+ )
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor,
+ do_normalize=False,
+ do_convert_grayscale=True,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer") and self.transformer is not None
+ else 128
+ )
+
+ # copied from diffusers.pipelines.huanyuandit.pipeline_huanyuandit.HunyuanDiTPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ device: torch.device = None,
+ dtype: torch.dtype = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: Optional[int] = None,
+ text_encoder_index: int = 0,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ dtype (`torch.dtype`):
+ torch dtype
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
+ max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
+ text_encoder_index (`int`, *optional*):
+ Index of the text encoder to use. `0` for clip and `1` for T5.
+ """
+ if dtype is None:
+ if self.text_encoder_2 is not None:
+ dtype = self.text_encoder_2.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ if device is None:
+ device = self._execution_device
+
+ tokenizers = [self.tokenizer, self.tokenizer_2]
+ text_encoders = [self.text_encoder, self.text_encoder_2]
+
+ tokenizer = tokenizers[text_encoder_index]
+ text_encoder = text_encoders[text_encoder_index]
+
+ if max_sequence_length is None:
+ if text_encoder_index == 0:
+ max_length = 77
+ if text_encoder_index == 1:
+ max_length = 256
+ else:
+ max_length = max_sequence_length
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ attention_mask=prompt_attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=negative_prompt_attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ prompt_embeds_2=None,
+ negative_prompt_embeds_2=None,
+ prompt_attention_mask_2=None,
+ negative_prompt_attention_mask_2=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is None and prompt_embeds_2 is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
+ raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
+ raise ValueError(
+ "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
+ )
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
+ if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
+ raise ValueError(
+ "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
+ f" {negative_prompt_embeds_2.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
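+ # Worked example (added for clarity; the numbers are illustrative, not from the original
+ # code): with num_inference_steps=50 and strength=0.8, init_timestep = min(int(50 * 0.8), 50) = 40,
+ # so t_start = 10 and only the last 40 scheduler timesteps are run; strength=1.0 keeps the
+ # full schedule, matching the `strength` semantics documented in `__call__` below.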
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ image,
+ timestep,
+ dtype,
+ device,
+ generator=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+
+ image = image.to(device=device, dtype=dtype)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ elif isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ init_latents = init_latents * self.vae.config.scaling_factor
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate(
+ "len(prompt) != len(image)",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
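+ # Note (added for clarity): the img2img starting latents are the VAE-encoded input image,
+ # scaled by `vae.config.scaling_factor` and noised to the strength-derived timestep via
+ # `scheduler.add_noise`, so a lower `strength` starts the denoising loop closer to the
+ # original image.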
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ strength: float = 0.8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: Optional[int] = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: Optional[float] = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_2: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_2: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ prompt_attention_mask_2: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[
+ Callable[[int, int, Dict], None],
+ PipelineCallback,
+ MultiPipelineCallbacks,
+ ]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = (1024, 1024),
+ target_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ use_resolution_binning: bool = True,
+ map: PipelineImageInput = None,
+ denoising_start: Optional[float] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation with HunyuanDiT.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if latents are passed directly they are not encoded again.
+ strength (`float`, *optional*, defaults to 0.8):
+ Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ height (`int`):
+ The height in pixels of the generated image.
+ width (`int`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
+ prompt_attention_mask_2 (`torch.Tensor`, *optional*):
+ Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
+ negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` and `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A callback function or a list of callback functions to be called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
+ the `._callback_tensor_inputs` attribute of your pipeline class.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
+ Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
+ The original size of the image. Used to calculate the time ids.
+ target_size (`Tuple[int, int]`, *optional*):
+ The target size of the image. Used to calculate the time ids.
+ crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
+ The top left coordinates of the crop. Used to calculate the time ids.
+ use_resolution_binning (`bool`, *optional*, defaults to `True`):
+ Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
+ standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
+ 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. default height and width
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ height = int((height // 16) * 16)
+ width = int((width // 16) * 16)
+
+ if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
+ width, height = map_to_standard_shapes(width, height)
+ height = int(height)
+ width = int(width)
+ logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ prompt_embeds_2,
+ negative_prompt_embeds_2,
+ prompt_attention_mask_2,
+ negative_prompt_attention_mask_2,
+ callback_on_step_end_tensor_inputs,
+ )
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ dtype=self.transformer.dtype,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=77,
+ text_encoder_index=0,
+ )
+ (
+ prompt_embeds_2,
+ negative_prompt_embeds_2,
+ prompt_attention_mask_2,
+ negative_prompt_attention_mask_2,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ dtype=self.transformer.dtype,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds_2,
+ negative_prompt_embeds=negative_prompt_embeds_2,
+ prompt_attention_mask=prompt_attention_mask_2,
+ negative_prompt_attention_mask=negative_prompt_attention_mask_2,
+ max_sequence_length=256,
+ text_encoder_index=1,
+ )
+
+ # 4. Preprocess image
+ init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ map = self.mask_processor.preprocess(
+ map,
+ height=height // self.vae_scale_factor,
+ width=width // self.vae_scale_factor,
+ ).to(device)
+
+ # 5. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # begin diff diff change
+ total_time_steps = num_inference_steps
+ # end diff diff change
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ init_image,
+ latent_timestep,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. create image_rotary_emb, style embedding & time ids
+ grid_height = height // 8 // self.transformer.config.patch_size
+ grid_width = width // 8 // self.transformer.config.patch_size
+ base_size = 512 // 8 // self.transformer.config.patch_size
+ grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
+ image_rotary_emb = get_2d_rotary_pos_embed(
+ self.transformer.inner_dim // self.transformer.num_heads,
+ grid_crops_coords,
+ (grid_height, grid_width),
+ )
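+ # Worked example (added for clarity; the numbers are illustrative and assume the default
+ # patch_size of 2): with height = width = 1024, grid_height = grid_width = 1024 // 8 // 2 = 64
+ # and base_size = 512 // 8 // 2 = 32, so the 2D rotary embedding is built for a 64x64 latent
+ # token grid with crop coordinates derived relative to the 32x32 base grid.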
+
+ style = torch.tensor([0], device=device)
+
+ target_size = target_size or (height, width)
+ add_time_ids = list(original_size + target_size + crops_coords_top_left)
+ add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
+ prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
+ add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
+ style = torch.cat([style] * 2, dim=0)
+
+ prompt_embeds = prompt_embeds.to(device=device)
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
+ prompt_embeds_2 = prompt_embeds_2.to(device=device)
+ prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
+ add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
+ batch_size * num_images_per_prompt, 1
+ )
+ style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ # preparations for diff diff
+ original_with_noise = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ init_image,
+ timesteps,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+ thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps
+ thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)
+ masks = map.squeeze() > (thresholds + (denoising_start or 0))
+ # end diff diff preparations
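+ # Note on the differential-diffusion masks above (added for clarity): `map` holds a
+ # per-latent-pixel value in [0, 1] and `thresholds` ramps linearly over the schedule, so
+ # masks[i] is True wherever map > i / total_time_steps (shifted by `denoising_start`).
+ # While a pixel's mask is True, the loop below re-injects the original image noised to the
+ # current timestep, so larger map values keep a region tied to the input for more steps.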
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+ # diff diff
+ if i == 0 and denoising_start is None:
+ latents = original_with_noise[:1]
+ else:
+ mask = masks[i].unsqueeze(0).to(latents.dtype)
+ mask = mask.unsqueeze(1) # fit shape
+ latents = original_with_noise[i] * mask + latents * (1 - mask)
+ # end diff diff
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
+ dtype=latent_model_input.dtype
+ )
+
+ # predict the noise residual
+ noise_pred = self.transformer(
+ latent_model_input,
+ t_expand,
+ encoder_hidden_states=prompt_embeds,
+ text_embedding_mask=prompt_attention_mask,
+ encoder_hidden_states_t5=prompt_embeds_2,
+ text_embedding_mask_t5=prompt_attention_mask_2,
+ image_meta_size=add_time_ids,
+ style=style,
+ image_rotary_emb=image_rotary_emb,
+ return_dict=False,
+ )[0]
+
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
+ negative_prompt_embeds_2 = callback_outputs.pop(
+ "negative_prompt_embeds_2", negative_prompt_embeds_2
+ )
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/pipeline_null_text_inversion.py b/diffusers/examples/community/pipeline_null_text_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e27b4647bc9fbc7e7db03ff3c32f2a1441e9719
--- /dev/null
+++ b/diffusers/examples/community/pipeline_null_text_inversion.py
@@ -0,0 +1,260 @@
+import inspect
+import os
+
+import numpy as np
+import torch
+import torch.nn.functional as nnf
+from PIL import Image
+from torch.optim.adam import Adam
+from tqdm import tqdm
+
+from diffusers import StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps=None,
+ device=None,
+ timesteps=None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class NullTextPipeline(StableDiffusionPipeline):
+ def get_noise_pred(self, latents, t, context):
+ latents_input = torch.cat([latents] * 2)
+ guidance_scale = 7.5
+ noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+ latents = self.prev_step(noise_pred, t, latents)
+ return latents
+
+ def get_noise_pred_single(self, latents, t, context):
+ noise_pred = self.unet(latents, t, encoder_hidden_states=context)["sample"]
+ return noise_pred
+
+ @torch.no_grad()
+ def image2latent(self, image_path):
+ image = Image.open(image_path).convert("RGB")
+ image = np.array(image)
+ image = torch.from_numpy(image).float() / 127.5 - 1
+ image = image.permute(2, 0, 1).unsqueeze(0).to(self.device)
+ latents = self.vae.encode(image)["latent_dist"].mean
+ latents = latents * 0.18215
+ return latents
+
+ @torch.no_grad()
+ def latent2image(self, latents):
+ latents = 1 / 0.18215 * latents.detach()
+ image = self.vae.decode(latents)["sample"].detach()
+ image = self.processor.postprocess(image, output_type="pil")[0]
+ return image
+
+ def prev_step(self, model_output, timestep, sample):
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ alpha_prod_t_prev = (
+ self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+ )
+ beta_prod_t = 1 - alpha_prod_t
+ pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
+ prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
+ return prev_sample
+
+ def next_step(self, model_output, timestep, sample):
+ timestep, next_timestep = (
+ min(timestep - self.scheduler.config.num_train_timesteps // self.num_inference_steps, 999),
+ timestep,
+ )
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+ next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
+ return next_sample
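+ # Note (added for clarity): prev_step and next_step both implement the deterministic DDIM
+ # update (eta = 0) from https://arxiv.org/abs/2010.02502,
+ #     x_{t'} = sqrt(alpha_{t'}) * x0_pred + sqrt(1 - alpha_{t'}) * eps,
+ # with x0_pred = (x_t - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t); prev_step walks the
+ # schedule toward t = 0, while next_step applies the same update in reverse for DDIM inversion.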
+
+ def null_optimization(self, latents, context, num_inner_steps, epsilon):
+ uncond_embeddings, cond_embeddings = context.chunk(2)
+ uncond_embeddings_list = []
+ latent_cur = latents[-1]
+ bar = tqdm(total=num_inner_steps * self.num_inference_steps)
+ for i in range(self.num_inference_steps):
+ uncond_embeddings = uncond_embeddings.clone().detach()
+ uncond_embeddings.requires_grad = True
+ optimizer = Adam([uncond_embeddings], lr=1e-2 * (1.0 - i / 100.0))
+ latent_prev = latents[len(latents) - i - 2]
+ t = self.scheduler.timesteps[i]
+ with torch.no_grad():
+ noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
+ for j in range(num_inner_steps):
+ noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
+ noise_pred = noise_pred_uncond + 7.5 * (noise_pred_cond - noise_pred_uncond)
+ latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
+ loss = nnf.mse_loss(latents_prev_rec, latent_prev)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ loss_item = loss.item()
+ bar.update()
+ if loss_item < epsilon + i * 2e-5:
+ break
+ for j in range(j + 1, num_inner_steps):
+ bar.update()
+ uncond_embeddings_list.append(uncond_embeddings[:1].detach())
+ with torch.no_grad():
+ context = torch.cat([uncond_embeddings, cond_embeddings])
+ latent_cur = self.get_noise_pred(latent_cur, t, context)
+ bar.close()
+ return uncond_embeddings_list
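+ # Note (added for clarity): null_optimization follows the Null-text Inversion approach
+ # (https://arxiv.org/abs/2211.09794): only the unconditional ("null") embedding is optimized
+ # at each timestep so that the classifier-free-guided step reproduces the DDIM-inverted
+ # latent trajectory, allowing the real image to be reconstructed and edited with standard
+ # guidance-scale sampling.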
+
+ @torch.no_grad()
+ def ddim_inversion_loop(self, latent, context):
+ self.scheduler.set_timesteps(self.num_inference_steps)
+ _, cond_embeddings = context.chunk(2)
+ all_latent = [latent]
+ latent = latent.clone().detach()
+ with torch.no_grad():
+ for i in range(0, self.num_inference_steps):
+ t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1]
+ noise_pred = self.unet(latent, t, encoder_hidden_states=cond_embeddings)["sample"]
+ latent = self.next_step(noise_pred, t, latent)
+ all_latent.append(latent)
+ return all_latent
+
+ def get_context(self, prompt):
+ uncond_input = self.tokenizer(
+ [""], padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ text_input = self.tokenizer(
+ [prompt],
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ context = torch.cat([uncond_embeddings, text_embeddings])
+ return context
+
+ def invert(
+ self, image_path: str, prompt: str, num_inner_steps=10, early_stop_epsilon=1e-6, num_inference_steps=50
+ ):
+ self.num_inference_steps = num_inference_steps
+ context = self.get_context(prompt)
+ latent = self.image2latent(image_path)
+ ddim_latents = self.ddim_inversion_loop(latent, context)
+ if os.path.exists(image_path + ".pt"):
+ uncond_embeddings = torch.load(image_path + ".pt")
+ else:
+ uncond_embeddings = self.null_optimization(ddim_latents, context, num_inner_steps, early_stop_epsilon)
+ uncond_embeddings = torch.stack(uncond_embeddings, 0)
+ torch.save(uncond_embeddings, image_path + ".pt")
+ return ddim_latents[-1], uncond_embeddings
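+ # Example usage (a sketch added for illustration; the model id, DDIM scheduler setup and
+ # file path are placeholder assumptions, not part of the original file, and assume
+ # `from diffusers import DDIMScheduler`):
+ #   pipe = NullTextPipeline.from_pretrained(
+ #       "runwayml/stable-diffusion-v1-5",
+ #       scheduler=DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler"),
+ #   ).to("cuda")
+ #   inverted_latent, uncond = pipe.invert("input.png", "a photo of a cat", num_inference_steps=50)
+ #   image = pipe("a photo of a cat", uncond, inverted_latent, num_inference_steps=50).images[0]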
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ uncond_embeddings,
+ inverted_latent,
+ num_inference_steps: int = 50,
+ timesteps=None,
+ guidance_scale=7.5,
+ negative_prompt=None,
+ num_images_per_prompt=1,
+ generator=None,
+ latents=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ output_type="pil",
+ ):
+ self._guidance_scale = guidance_scale
+ # 0. Default height and width to unet
+ height = self.unet.config.sample_size * self.vae_scale_factor
+ width = self.unet.config.sample_size * self.vae_scale_factor
+ # to deal with lora scaling and other possible forward hook
+ callback_steps = None
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ # 2. Define call parameter
+ device = self._execution_device
+ # 3. Encode input prompt
+ prompt_embeds, _ = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ latents = inverted_latent
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=uncond_embeddings[i])["sample"]
+ noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)["sample"]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+ progress_bar.update()
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ else:
+ image = latents
+ image = self.image_processor.postprocess(
+ image, output_type=output_type, do_denormalize=[True] * image.shape[0]
+ )
+ # Offload all models
+ self.maybe_free_model_hooks()
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=False)
diff --git a/diffusers/examples/community/pipeline_prompt2prompt.py b/diffusers/examples/community/pipeline_prompt2prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..508e84177928a86d537da98579b93407c698da21
--- /dev/null
+++ b/diffusers/examples/community/pipeline_prompt2prompt.py
@@ -0,0 +1,1422 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import abc
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from packaging import version
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models.attention import Attention
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+ StableDiffusionSafetyChecker,
+)
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class Prompt2PromptPipeline(
+ DiffusionPipeline,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ IPAdapterMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`] and is ignored for other schedulers.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ The keyword arguments to configure the edit are:
+                - edit_type (`str`): The edit type to apply. Can be either of `replace`, `refine`, or `reweight`.
+                - n_cross_replace (`float`): Fraction of the diffusion steps in which cross attention should be
+                  replaced.
+                - n_self_replace (`float`): Fraction of the diffusion steps in which self attention should be
+                  replaced.
+                - local_blend_words (`List[str]`, *optional*, defaults to `None`): Determines which area should be
+                  changed. If `None`, the whole image can be changed.
+                - equalizer_words (`List[str]`, *optional*, defaults to `None`): Required for edit type `reweight`.
+                  Determines which words should be enhanced.
+                - equalizer_strengths (`List[float]`, *optional*, defaults to `None`): Required for edit type
+                  `reweight`. Determines how much the words in `equalizer_words` should be enhanced.
+
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ self.controller = create_controller(
+ prompt,
+ cross_attention_kwargs,
+ num_inference_steps,
+ tokenizer=self.tokenizer,
+ device=self.device,
+ )
+ self.register_attention_control(self.controller) # add attention controller
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # step callback
+ latents = self.controller.step_callback(latents)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+        # 8. Decode the latents and run the safety checker
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+        # 9. Convert to the requested output format
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def register_attention_control(self, controller):
+ attn_procs = {}
+ cross_att_count = 0
+        for name in self.unet.attn_processors.keys():
+            # Determine which part of the UNet this attention processor belongs to.
+            if name.startswith("mid_block"):
+                place_in_unet = "mid"
+            elif name.startswith("up_blocks"):
+                place_in_unet = "up"
+            elif name.startswith("down_blocks"):
+                place_in_unet = "down"
+            else:
+                continue
+ cross_att_count += 1
+ attn_procs[name] = P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet)
+
+ self.unet.set_attn_processor(attn_procs)
+ controller.num_att_layers = cross_att_count
+
+
+class P2PCrossAttnProcessor:
+ def __init__(self, controller, place_in_unet):
+ super().__init__()
+ self.controller = controller
+ self.place_in_unet = place_in_unet
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ ):
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ query = attn.to_q(hidden_states)
+
+ is_cross = encoder_hidden_states is not None
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+ # one line change
+ self.controller(attention_probs, is_cross, self.place_in_unet)
+
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+def create_controller(
+ prompts: List[str],
+ cross_attention_kwargs: Dict,
+ num_inference_steps: int,
+ tokenizer,
+ device,
+) -> AttentionControl:
+ edit_type = cross_attention_kwargs.get("edit_type", None)
+ local_blend_words = cross_attention_kwargs.get("local_blend_words", None)
+ equalizer_words = cross_attention_kwargs.get("equalizer_words", None)
+ equalizer_strengths = cross_attention_kwargs.get("equalizer_strengths", None)
+ n_cross_replace = cross_attention_kwargs.get("n_cross_replace", 0.4)
+ n_self_replace = cross_attention_kwargs.get("n_self_replace", 0.4)
+
+ # only replace
+ if edit_type == "replace" and local_blend_words is None:
+ return AttentionReplace(
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ tokenizer=tokenizer,
+ device=device,
+ )
+
+ # replace + localblend
+ if edit_type == "replace" and local_blend_words is not None:
+ lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
+ return AttentionReplace(
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ lb,
+ tokenizer=tokenizer,
+ device=device,
+ )
+
+ # only refine
+ if edit_type == "refine" and local_blend_words is None:
+ return AttentionRefine(
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ tokenizer=tokenizer,
+ device=device,
+ )
+
+ # refine + localblend
+ if edit_type == "refine" and local_blend_words is not None:
+ lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
+ return AttentionRefine(
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ lb,
+ tokenizer=tokenizer,
+ device=device,
+ )
+
+ # reweight
+ if edit_type == "reweight":
+ assert (
+ equalizer_words is not None and equalizer_strengths is not None
+ ), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
+ assert len(equalizer_words) == len(
+ equalizer_strengths
+ ), "equalizer_words and equalizer_strengths must be of same length."
+ equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
+ return AttentionReweight(
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ tokenizer=tokenizer,
+ device=device,
+ equalizer=equalizer,
+ )
+
+ raise ValueError(f"Edit type {edit_type} not recognized. Use one of: replace, refine, reweight.")
+
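+
+# The comment block below is an illustrative, non-executed sketch of how the
+# `cross_attention_kwargs` consumed by `create_controller` could be assembled when calling
+# the pipeline defined above. `pipe` and `prompts` are placeholder names, not defined here.
+#
+#     prompts = ["a photo of a cat riding a bicycle", "a photo of a dog riding a bicycle"]
+#     edit_kwargs = {
+#         "edit_type": "replace",              # one of "replace", "refine", "reweight"
+#         "n_cross_replace": 0.4,              # fraction of steps with cross-attention replacement
+#         "n_self_replace": 0.4,               # fraction of steps with self-attention replacement
+#         "local_blend_words": ["cat", "dog"], # optional: restrict the edit to these regions
+#     }
+#     images = pipe(prompts, cross_attention_kwargs=edit_kwargs).images
+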
+
+class AttentionControl(abc.ABC):
+ def step_callback(self, x_t):
+ return x_t
+
+ def between_steps(self):
+ return
+
+ @property
+ def num_uncond_att_layers(self):
+ return 0
+
+ @abc.abstractmethod
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ raise NotImplementedError
+
+    def __call__(self, attn, is_cross: bool, place_in_unet: str):
+        if self.cur_att_layer >= self.num_uncond_att_layers:
+            # Under classifier-free guidance the first half of the batch is the unconditional
+            # pass; only the conditional half of the attention maps is edited.
+            h = attn.shape[0]
+            attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
+ self.cur_att_layer += 1
+ if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
+ self.cur_att_layer = 0
+ self.cur_step += 1
+ self.between_steps()
+ return attn
+
+ def reset(self):
+ self.cur_step = 0
+ self.cur_att_layer = 0
+
+ def __init__(self):
+ self.cur_step = 0
+ self.num_att_layers = -1
+ self.cur_att_layer = 0
+
+
+class EmptyControl(AttentionControl):
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ return attn
+
+
+class AttentionStore(AttentionControl):
+ @staticmethod
+ def get_empty_store():
+ return {
+ "down_cross": [],
+ "mid_cross": [],
+ "up_cross": [],
+ "down_self": [],
+ "mid_self": [],
+ "up_self": [],
+ }
+
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+ if attn.shape[1] <= 32**2: # avoid memory overhead
+ self.step_store[key].append(attn)
+ return attn
+
+ def between_steps(self):
+ if len(self.attention_store) == 0:
+ self.attention_store = self.step_store
+ else:
+ for key in self.attention_store:
+ for i in range(len(self.attention_store[key])):
+ self.attention_store[key][i] += self.step_store[key][i]
+ self.step_store = self.get_empty_store()
+
+ def get_average_attention(self):
+ average_attention = {
+ key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store
+ }
+ return average_attention
+
+ def reset(self):
+ super(AttentionStore, self).reset()
+ self.step_store = self.get_empty_store()
+ self.attention_store = {}
+
+ def __init__(self):
+ super(AttentionStore, self).__init__()
+ self.step_store = self.get_empty_store()
+ self.attention_store = {}
+
+
+class LocalBlend:
+ def __call__(self, x_t, attention_store):
+ k = 1
+ maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
+ maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words) for item in maps]
+ maps = torch.cat(maps, dim=1)
+ maps = (maps * self.alpha_layers).sum(-1).mean(1)
+ mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+ mask = F.interpolate(mask, size=(x_t.shape[2:]))
+ mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+ mask = mask.gt(self.threshold)
+ mask = (mask[:1] + mask[1:]).float()
+ x_t = x_t[:1] + mask * (x_t - x_t[:1])
+ return x_t
+
+ def __init__(
+ self,
+ prompts: List[str],
+        words: Union[List[str], List[List[str]]],
+ tokenizer,
+ device,
+ threshold=0.3,
+ max_num_words=77,
+ ):
+        self.max_num_words = max_num_words
+
+ alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
+ for i, (prompt, words_) in enumerate(zip(prompts, words)):
+ if isinstance(words_, str):
+ words_ = [words_]
+ for word in words_:
+ ind = get_word_inds(prompt, word, tokenizer)
+ alpha_layers[i, :, :, :, :, ind] = 1
+ self.alpha_layers = alpha_layers.to(device)
+ self.threshold = threshold
+
+
+class AttentionControlEdit(AttentionStore, abc.ABC):
+ def step_callback(self, x_t):
+ if self.local_blend is not None:
+ x_t = self.local_blend(x_t, self.attention_store)
+ return x_t
+
+ def replace_self_attention(self, attn_base, att_replace):
+ if att_replace.shape[2] <= 16**2:
+ return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
+ else:
+ return att_replace
+
+ @abc.abstractmethod
+ def replace_cross_attention(self, attn_base, att_replace):
+ raise NotImplementedError
+
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
+ # FIXME not replace correctly
+ if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+ h = attn.shape[0] // (self.batch_size)
+ attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
+            attn_base, attn_replace = attn[0], attn[1:]
+            if is_cross:
+                alpha_words = self.cross_replace_alpha[self.cur_step]
+                attn_replace_new = (
+                    self.replace_cross_attention(attn_base, attn_replace) * alpha_words
+                    + (1 - alpha_words) * attn_replace
+                )
+                attn[1:] = attn_replace_new
+            else:
+                attn[1:] = self.replace_self_attention(attn_base, attn_replace)
+ attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
+ return attn
+
+ def __init__(
+ self,
+ prompts,
+ num_steps: int,
+ cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+ self_replace_steps: Union[float, Tuple[float, float]],
+ local_blend: Optional[LocalBlend],
+ tokenizer,
+ device,
+ ):
+ super(AttentionControlEdit, self).__init__()
+ # add tokenizer and device here
+
+ self.tokenizer = tokenizer
+ self.device = device
+
+ self.batch_size = len(prompts)
+ self.cross_replace_alpha = get_time_words_attention_alpha(
+ prompts, num_steps, cross_replace_steps, self.tokenizer
+ ).to(self.device)
+ if isinstance(self_replace_steps, float):
+ self_replace_steps = 0, self_replace_steps
+ self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+        self.local_blend = local_blend  # defined externally and passed in
+
+
+class AttentionReplace(AttentionControlEdit):
+ def replace_cross_attention(self, attn_base, att_replace):
+ return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper)
+
+ def __init__(
+ self,
+ prompts,
+ num_steps: int,
+ cross_replace_steps: float,
+ self_replace_steps: float,
+ local_blend: Optional[LocalBlend] = None,
+ tokenizer=None,
+ device=None,
+ ):
+ super(AttentionReplace, self).__init__(
+ prompts,
+ num_steps,
+ cross_replace_steps,
+ self_replace_steps,
+ local_blend,
+ tokenizer,
+ device,
+ )
+ self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)
+
+
+class AttentionRefine(AttentionControlEdit):
+ def replace_cross_attention(self, attn_base, att_replace):
+ attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
+ attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
+ return attn_replace
+
+ def __init__(
+ self,
+ prompts,
+ num_steps: int,
+ cross_replace_steps: float,
+ self_replace_steps: float,
+ local_blend: Optional[LocalBlend] = None,
+ tokenizer=None,
+ device=None,
+ ):
+ super(AttentionRefine, self).__init__(
+ prompts,
+ num_steps,
+ cross_replace_steps,
+ self_replace_steps,
+ local_blend,
+ tokenizer,
+ device,
+ )
+ self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
+ self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device)
+ self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
+
+
+class AttentionReweight(AttentionControlEdit):
+ def replace_cross_attention(self, attn_base, att_replace):
+ if self.prev_controller is not None:
+ attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
+ attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
+ return attn_replace
+
+ def __init__(
+ self,
+ prompts,
+ num_steps: int,
+ cross_replace_steps: float,
+ self_replace_steps: float,
+ equalizer,
+ local_blend: Optional[LocalBlend] = None,
+ controller: Optional[AttentionControlEdit] = None,
+ tokenizer=None,
+ device=None,
+ ):
+ super(AttentionReweight, self).__init__(
+ prompts,
+ num_steps,
+ cross_replace_steps,
+ self_replace_steps,
+ local_blend,
+ tokenizer,
+ device,
+ )
+ self.equalizer = equalizer.to(self.device)
+ self.prev_controller = controller
+
+
+### util functions for all Edits
+def update_alpha_time_word(
+ alpha,
+ bounds: Union[float, Tuple[float, float]],
+ prompt_ind: int,
+ word_inds: Optional[torch.Tensor] = None,
+):
+ if isinstance(bounds, float):
+ bounds = 0, bounds
+ start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
+ if word_inds is None:
+ word_inds = torch.arange(alpha.shape[2])
+ alpha[:start, prompt_ind, word_inds] = 0
+ alpha[start:end, prompt_ind, word_inds] = 1
+ alpha[end:, prompt_ind, word_inds] = 0
+ return alpha
+
+
+def get_time_words_attention_alpha(
+ prompts,
+ num_steps,
+ cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
+ tokenizer,
+ max_num_words=77,
+):
+ if not isinstance(cross_replace_steps, dict):
+ cross_replace_steps = {"default_": cross_replace_steps}
+ if "default_" not in cross_replace_steps:
+ cross_replace_steps["default_"] = (0.0, 1.0)
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
+ for i in range(len(prompts) - 1):
+ alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
+ for key, item in cross_replace_steps.items():
+ if key != "default_":
+ inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
+ for i, ind in enumerate(inds):
+ if len(ind) > 0:
+ alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
+ alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
+ return alpha_time_words
+
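+
+# Worked example (descriptive, not executed): with two prompts, num_steps=50 and the
+# default cross_replace_steps=0.4, `get_time_words_attention_alpha` returns a tensor of
+# shape (51, 1, 1, 1, 77) that equals 1 for steps 0..19 and 0 afterwards for every word
+# position, since `update_alpha_time_word` computes start = int(0.0 * 51) = 0 and
+# end = int(0.4 * 51) = 20.
+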
+
+### util functions for LocalBlend and ReplacementEdit
+def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
+ split_text = text.split(" ")
+ if isinstance(word_place, str):
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
+ elif isinstance(word_place, int):
+ word_place = [word_place]
+ out = []
+ if len(word_place) > 0:
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+ cur_len, ptr = 0, 0
+
+ for i in range(len(words_encode)):
+ cur_len += len(words_encode[i])
+ if ptr in word_place:
+ out.append(i + 1)
+ if cur_len >= len(split_text[ptr]):
+ ptr += 1
+ cur_len = 0
+ return np.array(out)
+
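+
+# Illustrative sketch (not executed): `get_word_inds` maps a word of the prompt to the
+# positions of its tokens in the tokenized sequence, offset by 1 for the start-of-text
+# token. `tokenizer` is a placeholder for the pipeline's CLIP tokenizer.
+#
+#     inds = get_word_inds("a cat riding a bicycle", "cat", tokenizer)
+#     # -> array([2]) if "cat" is kept as a single token; the exact indices depend on the
+#     #    tokenizer's vocabulary.
+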
+
+### util functions for ReplacementEdit
+def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
+ words_x = x.split(" ")
+ words_y = y.split(" ")
+ if len(words_x) != len(words_y):
+ raise ValueError(
+ f"attention replacement edit can only be applied on prompts with the same length"
+ f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words."
+ )
+ inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
+ inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
+ inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
+ mapper = np.zeros((max_len, max_len))
+ i = j = 0
+ cur_inds = 0
+ while i < max_len and j < max_len:
+ if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
+ inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
+ if len(inds_source_) == len(inds_target_):
+ mapper[inds_source_, inds_target_] = 1
+ else:
+ ratio = 1 / len(inds_target_)
+ for i_t in inds_target_:
+ mapper[inds_source_, i_t] = ratio
+ cur_inds += 1
+ i += len(inds_source_)
+ j += len(inds_target_)
+ elif cur_inds < len(inds_source):
+ mapper[i, j] = 1
+ i += 1
+ j += 1
+ else:
+ mapper[j, j] = 1
+ i += 1
+ j += 1
+
+ return torch.from_numpy(mapper).float()
+
+
+def get_replacement_mapper(prompts, tokenizer, max_len=77):
+ x_seq = prompts[0]
+ mappers = []
+ for i in range(1, len(prompts)):
+ mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
+ mappers.append(mapper)
+ return torch.stack(mappers)
+
+
+### util functions for ReweightEdit
+def get_equalizer(
+ text: str,
+ word_select: Union[int, Tuple[int, ...]],
+ values: Union[List[float], Tuple[float, ...]],
+ tokenizer,
+):
+ if isinstance(word_select, (int, str)):
+ word_select = (word_select,)
+ equalizer = torch.ones(len(values), 77)
+ values = torch.tensor(values, dtype=torch.float32)
+ for word in word_select:
+ inds = get_word_inds(text, word, tokenizer)
+ equalizer[:, inds] = values
+ return equalizer
+
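+
+# Illustrative sketch (not executed) of building an equalizer for the `reweight` edit,
+# mirroring how `create_controller` calls `get_equalizer` on the second prompt.
+# `tokenizer` is a placeholder for the pipeline's CLIP tokenizer.
+#
+#     equalizer = get_equalizer("a smiling cat", ("smiling",), (2.0,), tokenizer=tokenizer)
+#     # `equalizer` has shape (1, 77): the columns of the tokens of "smiling" are set to 2.0
+#     # while every other column stays 1.0, so that word's cross attention is amplified.
+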
+
+### util functions for RefinementEdit
+class ScoreParams:
+ def __init__(self, gap, match, mismatch):
+ self.gap = gap
+ self.match = match
+ self.mismatch = mismatch
+
+ def mis_match_char(self, x, y):
+ if x != y:
+ return self.mismatch
+ else:
+ return self.match
+
+
+def get_matrix(size_x, size_y, gap):
+ matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+ matrix[0, 1:] = (np.arange(size_y) + 1) * gap
+ matrix[1:, 0] = (np.arange(size_x) + 1) * gap
+ return matrix
+
+
+def get_traceback_matrix(size_x, size_y):
+ matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+ matrix[0, 1:] = 1
+ matrix[1:, 0] = 2
+ matrix[0, 0] = 4
+ return matrix
+
+
+def global_align(x, y, score):
+ matrix = get_matrix(len(x), len(y), score.gap)
+ trace_back = get_traceback_matrix(len(x), len(y))
+ for i in range(1, len(x) + 1):
+ for j in range(1, len(y) + 1):
+ left = matrix[i, j - 1] + score.gap
+ up = matrix[i - 1, j] + score.gap
+ diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
+ matrix[i, j] = max(left, up, diag)
+ if matrix[i, j] == left:
+ trace_back[i, j] = 1
+ elif matrix[i, j] == up:
+ trace_back[i, j] = 2
+ else:
+ trace_back[i, j] = 3
+ return matrix, trace_back
+
+
+def get_aligned_sequences(x, y, trace_back):
+ x_seq = []
+ y_seq = []
+ i = len(x)
+ j = len(y)
+ mapper_y_to_x = []
+ while i > 0 or j > 0:
+ if trace_back[i, j] == 3:
+ x_seq.append(x[i - 1])
+ y_seq.append(y[j - 1])
+ i = i - 1
+ j = j - 1
+ mapper_y_to_x.append((j, i))
+ elif trace_back[i][j] == 1:
+ x_seq.append("-")
+ y_seq.append(y[j - 1])
+ j = j - 1
+ mapper_y_to_x.append((j, -1))
+ elif trace_back[i][j] == 2:
+ x_seq.append(x[i - 1])
+ y_seq.append("-")
+ i = i - 1
+ elif trace_back[i][j] == 4:
+ break
+ mapper_y_to_x.reverse()
+ return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
+
+
+def get_mapper(x: str, y: str, tokenizer, max_len=77):
+ x_seq = tokenizer.encode(x)
+ y_seq = tokenizer.encode(y)
+ score = ScoreParams(0, 1, -1)
+ matrix, trace_back = global_align(x_seq, y_seq, score)
+ mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
+ alphas = torch.ones(max_len)
+ alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
+ mapper = torch.zeros(max_len, dtype=torch.int64)
+ mapper[: mapper_base.shape[0]] = mapper_base[:, 1]
+ mapper[mapper_base.shape[0] :] = len(y_seq) + torch.arange(max_len - len(y_seq))
+ return mapper, alphas
+
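+
+# Note (descriptive): `get_mapper` aligns the token sequences of the source and target
+# prompts with the Needleman-Wunsch style global alignment implemented above
+# (`global_align` / `get_aligned_sequences`). For each target token position it returns the
+# index of the corresponding source token (`mapper`) and an alpha that is 1 where a source
+# token exists and 0 where the target token was inserted (`alphas`); `AttentionRefine` uses
+# these to blend the base and edited cross-attention maps.
+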
+
+def get_refinement_mapper(prompts, tokenizer, max_len=77):
+ x_seq = prompts[0]
+ mappers, alphas = [], []
+ for i in range(1, len(prompts)):
+ mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
+ mappers.append(mapper)
+ alphas.append(alpha)
+ return torch.stack(mappers), torch.stack(alphas)
diff --git a/diffusers/examples/community/pipeline_sdxl_style_aligned.py b/diffusers/examples/community/pipeline_sdxl_style_aligned.py
new file mode 100644
index 0000000000000000000000000000000000000000..8328bc2caed92f02d21aea7ce386da3dfc6d94a9
--- /dev/null
+++ b/diffusers/examples/community/pipeline_sdxl_style_aligned.py
@@ -0,0 +1,1912 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Based on [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133).
+# Authors: Amir Hertz, Andrey Voynov, Shlomi Fruchter, Daniel Cohen-Or
+# Project Page: https://style-aligned-gen.github.io/
+# Code: https://github.com/google/style-aligned
+#
+# Adapted to Diffusers by [Aryan V S](https://github.com/a-r-r-o-w/).
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ Attention,
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from typing import List
+
+ >>> import torch
+ >>> from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ >>> from PIL import Image
+
+ >>> model_id = "a-r-r-o-w/dreamshaper-xl-turbo"
+ >>> pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16", custom_pipeline="pipeline_sdxl_style_aligned")
+ >>> pipe = pipe.to("cuda")
+
+        >>> # Enable memory saving techniques
+ >>> pipe.enable_vae_slicing()
+ >>> pipe.enable_vae_tiling()
+
+ >>> prompt = [
+ ... "a toy train. macro photo. 3d game asset",
+ ... "a toy airplane. macro photo. 3d game asset",
+ ... "a toy bicycle. macro photo. 3d game asset",
+ ... "a toy car. macro photo. 3d game asset",
+ ... ]
+ >>> negative_prompt = "low quality, worst quality, "
+
+ >>> # Enable StyleAligned
+ >>> pipe.enable_style_aligned(
+ ... share_group_norm=False,
+ ... share_layer_norm=False,
+ ... share_attention=True,
+ ... adain_queries=True,
+ ... adain_keys=True,
+ ... adain_values=False,
+ ... full_attention_share=False,
+ ... shared_score_scale=1.0,
+ ... shared_score_shift=0.0,
+ ... only_self_level=0.0,
+        ... )
+
+ >>> # Run inference
+ >>> images = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... guidance_scale=2,
+ ... height=1024,
+ ... width=1024,
+ ... num_inference_steps=10,
+ ... generator=torch.Generator().manual_seed(42),
+        ... ).images
+
+ >>> # Disable StyleAligned if you do not wish to use it anymore
+ >>> pipe.disable_style_aligned()
+ ```
+"""
+
+
+def expand_first(feat: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+ b = feat.shape[0]
+ feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
+ if scale == 1:
+ feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
+ else:
+ feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
+ feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
+ return feat_style.reshape(*feat.shape)
+
+
+def concat_first(feat: torch.Tensor, dim: int = 2, scale: float = 1.0) -> torch.Tensor:
+ feat_style = expand_first(feat, scale=scale)
+ return torch.cat((feat, feat_style), dim=dim)
+
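+
+# Note (descriptive): under classifier-free guidance the batch is laid out as
+# [uncond_0, ..., uncond_{n-1}, cond_0, ..., cond_{n-1}], so `feat[0]` and `feat[b // 2]` in
+# `expand_first` are the features of the first image (the style reference) in each half.
+# `expand_first` broadcasts those reference features over each half of the batch, and
+# `concat_first` appends them along the token dimension so that every image in the batch
+# can attend to the reference image's keys and values.
+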
+
+def calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
+ feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
+ feat_mean = feat.mean(dim=-2, keepdims=True)
+ return feat_mean, feat_std
+
+
+def adain(feat: torch.Tensor) -> torch.Tensor:
+ feat_mean, feat_std = calc_mean_std(feat)
+ feat_style_mean = expand_first(feat_mean)
+ feat_style_std = expand_first(feat_std)
+ feat = (feat - feat_mean) / feat_std
+ feat = feat * feat_style_std + feat_style_mean
+ return feat
+
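+
+# Note (descriptive): `adain` applies Adaptive Instance Normalization along the token
+# dimension: each item's features are normalized to zero mean and unit variance and then
+# rescaled with the mean/std of the style reference (broadcast via `expand_first`), so the
+# queries and keys of every image share the first-order statistics of the reference image.
+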
+
+def get_switch_vec(total_num_layers, level):
+ if level == 0:
+ return torch.zeros(total_num_layers, dtype=torch.bool)
+ if level == 1:
+ return torch.ones(total_num_layers, dtype=torch.bool)
+ to_flip = level > 0.5
+ if to_flip:
+ level = 1 - level
+ num_switch = int(level * total_num_layers)
+ vec = torch.arange(total_num_layers)
+ vec = vec % (total_num_layers // num_switch)
+ vec = vec == 0
+ if to_flip:
+ vec = ~vec
+ return vec
+
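+
+# Worked example (descriptive, not executed): `get_switch_vec(10, 0.3)` turns on roughly a
+# 0.3 fraction of the layers: num_switch = int(0.3 * 10) = 3 and
+# torch.arange(10) % (10 // 3) == 0 selects indices 0, 3, 6 and 9, i.e.
+# tensor([True, False, False, True, False, False, True, False, False, True]).
+# Levels above 0.5 are handled by computing the complement and inverting the result.
+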
+
+class SharedAttentionProcessor(AttnProcessor2_0):
+ def __init__(
+ self,
+ share_attention: bool = True,
+ adain_queries: bool = True,
+ adain_keys: bool = True,
+ adain_values: bool = False,
+ full_attention_share: bool = False,
+ shared_score_scale: float = 1.0,
+ shared_score_shift: float = 0.0,
+ ):
+ r"""Shared Attention Processor as proposed in the StyleAligned paper."""
+ super().__init__()
+ self.share_attention = share_attention
+ self.adain_queries = adain_queries
+ self.adain_keys = adain_keys
+ self.adain_values = adain_values
+ self.full_attention_share = full_attention_share
+ self.shared_score_scale = shared_score_scale
+ self.shared_score_shift = shared_score_shift
+
+ def shifted_scaled_dot_product_attention(
+ self, attn: Attention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> torch.Tensor:
+ logits = torch.einsum("bhqd,bhkd->bhqk", query, key) * attn.scale
+ logits[:, :, :, query.shape[2] :] += self.shared_score_shift
+ probs = logits.softmax(-1)
+ return torch.einsum("bhqk,bhkd->bhqd", probs, value)
+
+ def shared_call(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ residual = hidden_states
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if self.adain_queries:
+ query = adain(query)
+ if self.adain_keys:
+ key = adain(key)
+ if self.adain_values:
+ value = adain(value)
+ if self.share_attention:
+ key = concat_first(key, -2, scale=self.shared_score_scale)
+ value = concat_first(value, -2)
+ if self.shared_score_shift != 0:
+ hidden_states = self.shifted_scaled_dot_product_attention(attn, query, key, value)
+ else:
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ else:
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+ return hidden_states
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ if self.full_attention_share:
+ b, n, d = hidden_states.shape
+ k = 2
+ hidden_states = hidden_states.view(k, b, n, d).permute(0, 1, 3, 2).contiguous().view(-1, n, d)
+ # hidden_states = einops.rearrange(hidden_states, "(k b) n d -> k (b n) d", k=2)
+ hidden_states = super().__call__(
+ attn,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = hidden_states.view(k, b, n, d).permute(0, 1, 3, 2).contiguous().view(-1, n, d)
+ # hidden_states = einops.rearrange(hidden_states, "k (b n) d -> (k b) n d", n=n)
+ else:
+ hidden_states = self.shared_call(attn, hidden_states, hidden_states, attention_mask, **kwargs)
+
+ return hidden_states
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class StyleAlignedSDXLPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This pipeline also adds experimental support for [StyleAligned](https://arxiv.org/abs/2312.02133). It can
+ be enabled/disabled using `.enable_style_aligned()` or `.disable_style_aligned()` respectively.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "negative_add_time_ids",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ self.default_sample_size = self.unet.config.sample_size
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+                # We are always interested only in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
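+ # The per-encoder hidden states are concatenated along the feature dimension; for the standard
+ # SDXL checkpoints this yields 768 (CLIP ViT-L) + 1280 (OpenCLIP ViT-bigG) = 2048 features.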
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We always use the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
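+ # Used for IP-Adapter conditioning: the unconditional branch is built either from an all-zeros
+ # image (hidden-states path) or from a zero tensor shaped like the image embeddings (pooled path).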
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
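+ # Example (assuming 1000 training timesteps): denoising_start=0.8 gives a cutoff of 200, so only
+ # the timesteps below 200 -- i.e. the last ~20% of the schedule -- are actually denoised here.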
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ timesteps = timesteps[-num_inference_steps:]
+ return timesteps, num_inference_steps
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(
+ self,
+ image,
+ mask,
+ width,
+ height,
+ num_channels_latents,
+ timestep,
+ batch_size,
+ num_images_per_prompt,
+ dtype,
+ device,
+ generator=None,
+ add_noise=True,
+ latents=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ batch_size *= num_images_per_prompt
+
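+ # Three cases are handled below: no `image` (text-to-image: sample pure noise latents), `image`
+ # without `mask` (image-to-image: encode the image and add noise according to `timestep`), and
+ # `image` with `mask` (inpainting: optionally also return the noise and the image latents).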
+ if image is None:
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ elif mask is None:
+ if not isinstance(image, (torch.Tensor, Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.text_encoder_2.to("cpu")
+ torch.cuda.empty_cache()
+
+ image = image.to(device=device, dtype=dtype)
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ if add_noise:
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+ return latents
+
+ else:
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initialise them as image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
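+ # A 4-channel `masked_image` is assumed to already be in VAE latent space and is used as-is;
+ # otherwise it is encoded with the VAE further below.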
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
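+ # SDXL micro-conditioning: the six scalars (original_size, crops_coords_top_left, target_size)
+ # are embedded and added to the time embedding, e.g. ((1024, 1024), (0, 0), (1024, 1024))
+ # becomes [1024, 1024, 0, 0, 1024, 1024].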
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ FusedAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch 2.0 is used, the attention block does not need
+ # to be in float32, which can save a lot of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ def _enable_shared_attention_processors(
+ self,
+ share_attention: bool,
+ adain_queries: bool,
+ adain_keys: bool,
+ adain_values: bool,
+ full_attention_share: bool,
+ shared_score_scale: float,
+ shared_score_shift: float,
+ only_self_level: float,
+ ):
+ r"""Helper method to enable usage of Shared Attention Processor."""
+ attn_procs = {}
+ num_self_layers = len([name for name in self.unet.attn_processors.keys() if "attn1" in name])
+
+ only_self_vec = get_switch_vec(num_self_layers, only_self_level)
+
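+ # Only self-attention layers ("attn1") can receive the SharedAttentionProcessor; cross-attention
+ # layers ("attn2") always keep the default AttnProcessor2_0, and `only_self_vec` selects which
+ # self-attention layers fall back to the default as well.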
+ for i, name in enumerate(self.unet.attn_processors.keys()):
+ is_self_attention = "attn1" in name
+ if is_self_attention:
+ if only_self_vec[i // 2]:
+ attn_procs[name] = AttnProcessor2_0()
+ else:
+ attn_procs[name] = SharedAttentionProcessor(
+ share_attention=share_attention,
+ adain_queries=adain_queries,
+ adain_keys=adain_keys,
+ adain_values=adain_values,
+ full_attention_share=full_attention_share,
+ shared_score_scale=shared_score_scale,
+ shared_score_shift=shared_score_shift,
+ )
+ else:
+ attn_procs[name] = AttnProcessor2_0()
+
+ self.unet.set_attn_processor(attn_procs)
+
+ def _disable_shared_attention_processors(self):
+ r"""
+ Helper method to disable usage of the Shared Attention Processor. All processors
+ are reset to the default `AttnProcessor2_0`, which requires PyTorch 2.0 or newer.
+ """
+ attn_procs = {}
+
+ for i, name in enumerate(self.unet.attn_processors.keys()):
+ attn_procs[name] = AttnProcessor2_0()
+
+ self.unet.set_attn_processor(attn_procs)
+
+ def _register_shared_norm(self, share_group_norm: bool = True, share_layer_norm: bool = True):
+ r"""Helper method to register shared group/layer normalization layers."""
+
+ def register_norm_forward(norm_layer: Union[nn.GroupNorm, nn.LayerNorm]) -> Union[nn.GroupNorm, nn.LayerNorm]:
+ if not hasattr(norm_layer, "orig_forward"):
+ setattr(norm_layer, "orig_forward", norm_layer.forward)
+ orig_forward = norm_layer.orig_forward
+
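+ # The wrapped forward appears to append the reference features along the token dimension
+ # (via `concat_first`) so the normalization statistics are shared across the batch, then
+ # drops the appended tokens so the output shape is unchanged.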
+ def forward_(hidden_states: torch.Tensor) -> torch.Tensor:
+ n = hidden_states.shape[-2]
+ hidden_states = concat_first(hidden_states, dim=-2)
+ hidden_states = orig_forward(hidden_states)
+ return hidden_states[..., :n, :]
+
+ norm_layer.forward = forward_
+ return norm_layer
+
+ def get_norm_layers(pipeline_, norm_layers_: Dict[str, List[Union[nn.GroupNorm, nn.LayerNorm]]]):
+ if isinstance(pipeline_, nn.LayerNorm) and share_layer_norm:
+ norm_layers_["layer"].append(pipeline_)
+ if isinstance(pipeline_, nn.GroupNorm) and share_group_norm:
+ norm_layers_["group"].append(pipeline_)
+ else:
+ for layer in pipeline_.children():
+ get_norm_layers(layer, norm_layers_)
+
+ norm_layers = {"group": [], "layer": []}
+ get_norm_layers(self.unet, norm_layers)
+
+ norm_layers_list = []
+ for key in ["group", "layer"]:
+ for layer in norm_layers[key]:
+ norm_layers_list.append(register_norm_forward(layer))
+
+ return norm_layers_list
+
+ @property
+ def style_aligned_enabled(self):
+ r"""Returns whether StyleAligned has been enabled in the pipeline or not."""
+ return hasattr(self, "_style_aligned_norm_layers") and self._style_aligned_norm_layers is not None
+
+ def enable_style_aligned(
+ self,
+ share_group_norm: bool = True,
+ share_layer_norm: bool = True,
+ share_attention: bool = True,
+ adain_queries: bool = True,
+ adain_keys: bool = True,
+ adain_values: bool = False,
+ full_attention_share: bool = False,
+ shared_score_scale: float = 1.0,
+ shared_score_shift: float = 0.0,
+ only_self_level: float = 0.0,
+ ):
+ r"""
+ Enables the StyleAligned mechanism as in https://arxiv.org/abs/2312.02133.
+
+ Args:
+ share_group_norm (`bool`, defaults to `True`):
+ Whether or not to use shared group normalization layers.
+ share_layer_norm (`bool`, defaults to `True`):
+ Whether or not to use shared layer normalization layers.
+ share_attention (`bool`, defaults to `True`):
+ Whether or not to use attention sharing between batch images.
+ adain_queries (`bool`, defaults to `True`):
+ Whether or not to apply the AdaIn operation on attention queries.
+ adain_keys (`bool`, defaults to `True`):
+ Whether or not to apply the AdaIn operation on attention keys.
+ adain_values (`bool`, defaults to `False`):
+ Whether or not to apply the AdaIn operation on attention values.
+ full_attention_share (`bool`, defaults to `False`):
+ Whether or not to use full attention sharing between all images in a batch. Can
+ lead to content leakage within each batch and some loss in diversity.
+ shared_score_scale (`float`, defaults to `1.0`):
+ Scale applied to the shared attention scores.
+ shared_score_shift (`float`, defaults to `0.0`):
+ Additive shift applied to the shared attention scores.
+ only_self_level (`float`, defaults to `0.0`):
+ Controls which self-attention layers keep the default (non-shared) attention processor.
+ """
+ self._style_aligned_norm_layers = self._register_shared_norm(share_group_norm, share_layer_norm)
+ self._enable_shared_attention_processors(
+ share_attention=share_attention,
+ adain_queries=adain_queries,
+ adain_keys=adain_keys,
+ adain_values=adain_values,
+ full_attention_share=full_attention_share,
+ shared_score_scale=shared_score_scale,
+ shared_score_shift=shared_score_shift,
+ only_self_level=only_self_level,
+ )
+
+ def disable_style_aligned(self):
+ r"""Disables the StyleAligned mechanism if it had been previously enabled."""
+ if self.style_aligned_enabled:
+ for layer in self._style_aligned_norm_layers:
+ layer.forward = layer.orig_forward
+
+ self._style_aligned_norm_layers = None
+ self._disable_shared_attention_processors()
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Guidance scale values (one per sample) for which to generate embedding vectors.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
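+ # Standard sinusoidal (Fourier) embedding of the scaled guidance weight, analogous to a timestep
+ # embedding: half sine features, half cosine features, zero-padded if `embedding_dim` is odd.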
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
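+ # When `time_cond_proj_dim` is set, the UNet is guidance-distilled (e.g. LCM) and consumes the
+ # guidance scale as an embedding, so classifier-free guidance is disabled here.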
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def denoising_start(self):
+ return self._denoising_start
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: Optional[PipelineImageInput] = None,
+ mask_image: Optional[PipelineImageInput] = None,
+ masked_image_latents: Optional[torch.Tensor] = None,
+ strength: float = 0.3,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders.
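+ image (`PipelineImageInput`, *optional*):
+ The image used as the starting point for the denoising process (image-to-image) or, together with
+ `mask_image`, for inpainting. If not provided, latents are sampled from pure noise (text-to-image).
+ mask_image (`PipelineImageInput`, *optional*):
+ A mask marking the regions of `image` to repaint (white/high values are regenerated, black/low
+ values are preserved). Ignored if `image` is not provided.
+ masked_image_latents (`torch.Tensor`, *optional*):
+ Pre-encoded latents of the masked image. If not provided, they are computed from `image` and
+ `mask_image` when needed.
+ strength (`float`, *optional*, defaults to 0.3):
+ Indicates how much to transform the provided `image`. Must be between 0 and 1; `image` is used as a
+ starting point and more noise is added the higher the `strength`. Ignored when `image` is not provided.
+ denoising_start (`float`, *optional*):
+ When specified, the fraction (between 0.0 and 1.0) of the denoising process assumed to have already
+ been completed for the provided `image`; the corresponding initial timesteps are skipped and no
+ additional noise is added.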
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. It is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ height=height,
+ width=width,
+ callback_steps=callback_steps,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Preprocess image and mask_image
+ if image is not None:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+ image = image.to(device=self.device, dtype=prompt_embeds.dtype)
+
+ if mask_image is not None:
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+ mask = mask.to(device=self.device, dtype=prompt_embeds.dtype)
+
+ if masked_image_latents is not None:
+ masked_image = masked_image_latents
+ elif image.shape[1] == 4:
+ # if image is in latent space, we can't mask it
+ masked_image = None
+ else:
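+ # zero out the masked (mask >= 0.5) pixels before VAE encoding; only unmasked content is kept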
+ masked_image = image * (mask < 0.5)
+ else:
+ mask = None
+
+ # 4. Prepare timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ if image is not None:
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ is_strength_max = strength == 1.0
+ add_noise = self.denoising_start is None
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ latents = self.prepare_latents(
+ image=image,
+ mask=mask,
+ width=width,
+ height=height,
+ num_channels_latents=num_channels_latents,
+ timestep=latent_timestep,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ add_noise=add_noise,
+ latents=latents,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if mask is not None:
+ if return_image_latents:
+ latents, noise, image_latents = latents
+ else:
+ latents, noise = latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask=mask,
+ masked_image=masked_image,
+ batch_size=batch_size * num_images_per_prompt,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ )
+
+ # Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
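+ # The unconditional embeddings are placed first so that `noise_pred.chunk(2)` in the denoising
+ # loop yields (uncond, text) in that order.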
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ if ip_adapter_image is not None:
+ output_hidden_state = not isinstance(self.unet.encoder_hid_proj, ImageProjection)
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+ image_embeds = image_embeds.to(device)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 9. Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
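+ # For a standard 4-channel UNet, inpainting is done by re-noising the original image latents to
+ # the next timestep and pasting them back outside the mask after every step, so only the masked
+ # region is actually regenerated.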
+ if mask is not None and num_channels_unet == 4:
+ init_latents_proper = image_latents
+
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_3_differential_img2img.py b/diffusers/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cee5ecbc1411bd763c5fe778c2f508c3074088a
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
@@ -0,0 +1,981 @@
+# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import (
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.transformers import SD3Transformer2DModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+
+ >>> from diffusers import AutoPipelineForImage2Image
+ >>> from diffusers.utils import load_image
+
+ >>> device = "cuda"
+ >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
+ >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+ >>> pipe = pipe.to(device)
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ >>> init_image = load_image(url).resize((512, 512))
+
+ >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
+
+ >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
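+
+# Illustrative sketch (not executed): `retrieve_timesteps` is driven either by a step count or by
+# exactly one explicit schedule, never both. Assuming `scheduler` is a scheduler whose
+# `set_timesteps` accepts the chosen keyword:
+#
+#   timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=28, device="cuda")
+#   timesteps, n = retrieve_timesteps(scheduler, sigmas=[1.0, 0.75, 0.5, 0.25, 0.0], device="cuda")
+#
+# Passing both `timesteps` and `sigmas`, or a keyword the scheduler does not support, raises a
+# ValueError.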
+
+
+class StableDiffusion3DifferentialImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Args:
+ transformer ([`SD3Transformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
+            with an additional projection layer that is initialized with a diagonal matrix whose dimension is the
+            `hidden_size`.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ text_encoder_3 ([`T5EncoderModel`]):
+ Frozen text-encoder. Stable Diffusion 3 uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_3 (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
+
+ def __init__(
+ self,
+ transformer: SD3Transformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer_2: CLIPTokenizer,
+ text_encoder_3: T5EncoderModel,
+ tokenizer_3: T5TokenizerFast,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ text_encoder_3=text_encoder_3,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ tokenizer_3=tokenizer_3,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+ )
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
+ )
+
+ self.tokenizer_max_length = self.tokenizer.model_max_length
+ self.default_sample_size = self.transformer.config.sample_size
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if self.text_encoder_3 is None:
+ return torch.zeros(
+ (
+ batch_size * num_images_per_prompt,
+ self.tokenizer_max_length,
+ self.transformer.config.joint_attention_dim,
+ ),
+ device=device,
+ dtype=dtype,
+ )
+
+ text_inputs = self.tokenizer_3(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
+
+ dtype = self.text_encoder_3.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ clip_skip: Optional[int] = None,
+ clip_model_index: int = 0,
+ ):
+ device = device or self._execution_device
+
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
+
+ tokenizer = clip_tokenizers[clip_model_index]
+ text_encoder = clip_text_encoders[clip_model_index]
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+ pooled_prompt_embeds = prompt_embeds[0]
+
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ prompt_3: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clip_skip: Optional[int] = None,
+ max_sequence_length: int = 256,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
+ used in all text-encoders
+            device (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+            negative_prompt_3 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ prompt_3 = prompt_3 or prompt
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=clip_skip,
+ clip_model_index=0,
+ )
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+ prompt=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=clip_skip,
+ clip_model_index=1,
+ )
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
+
+ t5_prompt_embed = self._get_t5_prompt_embeds(
+ prompt=prompt_3,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+ negative_prompt_3 = (
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
+ negative_prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=0,
+ )
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+ negative_prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=1,
+ )
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
+
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
+ prompt=negative_prompt_3,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
+ negative_clip_prompt_embeds,
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
+ )
+
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
+ negative_pooled_prompt_embeds = torch.cat(
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ prompt_3,
+ strength,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ negative_prompt_3=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_3 is not None and prompt_embeds is not None:
+ raise ValueError(
+                f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
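+        # e.g. num_inference_steps=50 and strength=0.6 give init_timestep=30 and t_start=20, so only
+        # the last 30 timesteps of the schedule are run and denoising starts from a partially noised
+        # version of the input rather than from pure noise.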
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, height, width, image, timestep, dtype, device, generator=None
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ elif isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
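+        # Map the VAE posterior sample into the latent space the transformer expects: the SD3 VAE
+        # applies a shift in addition to the usual scaling factor.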
+ init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
+ latents = init_latents.to(device=device, dtype=dtype)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ image: PipelineImageInput = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ map: PipelineImageInput = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
+                be used instead.
+ prompt_3 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` will
+                be used instead.
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
+ of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 256): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # 0. Default height and width
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ strength,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 3. Preprocess image
+ init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+
+ map = self.mask_processor.preprocess(
+ map, height=height // self.vae_scale_factor, width=width // self.vae_scale_factor
+ ).to(device)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # begin diff diff change
+ total_time_steps = num_inference_steps
+ # end diff diff change
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ if latents is None:
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ init_image,
+ latent_timestep,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # preparations for diff diff
+ original_with_noise = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ init_image,
+ timesteps,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
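+        # One threshold per step of the full schedule: at loop step i, pixels whose map value exceeds
+        # i / total_time_steps are re-injected below from the noised original, so larger map values
+        # keep a region on the input image's trajectory for longer (i.e. it is changed less).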
+ thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps
+ thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)
+ masks = map.squeeze() > thresholds
+ # end diff diff preparations
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # diff diff
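+                # Blend: preserved regions are reset to the original re-noised to the current noise
+                # level, everything else keeps the freely denoised latents; the very first step
+                # starts entirely from the noised original.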
+ if i == 0:
+ latents = original_with_noise[:1]
+ else:
+ mask = masks[i].unsqueeze(0).to(latents.dtype)
+ mask = mask.unsqueeze(1) # fit shape
+ latents = original_with_noise[i] * mask + latents * (1 - mask)
+ # end diff diff
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusion3PipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_boxdiff.py b/diffusers/examples/community/pipeline_stable_diffusion_boxdiff.py
new file mode 100644
index 0000000000000000000000000000000000000000..6490c1400138c09c7ab85b5e0ea79951586df281
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_boxdiff.py
@@ -0,0 +1,1705 @@
+# Copyright 2024 Jingyang Zhang and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import inspect
+import math
+import numbers
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+class GaussianSmoothing(nn.Module):
+ """
+ Copied from official repo: https://github.com/showlab/BoxDiff/blob/master/utils/gaussian_smoothing.py
+    Apply Gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed separately for each channel
+    in the input using a depthwise convolution.
+ Arguments:
+ channels (int, sequence): Number of channels of the input tensors. Output will
+ have this number of channels as well.
+ kernel_size (int, sequence): Size of the gaussian kernel.
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
+ dim (int, optional): The number of dimensions of the data.
+ Default value is 2 (spatial).
+ """
+
+ def __init__(self, channels, kernel_size, sigma, dim=2):
+ super(GaussianSmoothing, self).__init__()
+ if isinstance(kernel_size, numbers.Number):
+ kernel_size = [kernel_size] * dim
+ if isinstance(sigma, numbers.Number):
+ sigma = [sigma] * dim
+
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer("weight", kernel)
+ self.groups = channels
+
+ if dim == 1:
+ self.conv = F.conv1d
+ elif dim == 2:
+ self.conv = F.conv2d
+ elif dim == 3:
+ self.conv = F.conv3d
+ else:
+ raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))
+
+ def forward(self, input):
+ """
+ Apply gaussian filter to input.
+ Arguments:
+ input (torch.Tensor): Input to apply gaussian filter on.
+ Returns:
+ filtered (torch.Tensor): Filtered output.
+ """
+ return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
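+
+# Illustrative sketch (not executed): smoothing a single-channel 2D attention map. The depthwise
+# convolution above uses no padding, so the caller is expected to pad the input itself, e.g.:
+#
+#   smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2)
+#   attn_map = torch.rand(1, 1, 16, 16)                       # (batch, channels, H, W)
+#   padded = F.pad(attn_map, (1, 1, 1, 1), mode="reflect")    # keep 16x16 after the 3x3 kernel
+#   smoothed = smoothing(padded)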
+
+
+class AttendExciteCrossAttnProcessor:
+ def __init__(self, attnstore, place_in_unet):
+ super().__init__()
+ self.attnstore = attnstore
+ self.place_in_unet = place_in_unet
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=1)
+ query = attn.to_q(hidden_states)
+
+ is_cross = encoder_hidden_states is not None
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ self.attnstore(attention_probs, is_cross, self.place_in_unet)
+
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class AttentionControl(abc.ABC):
+ def step_callback(self, x_t):
+ return x_t
+
+ def between_steps(self):
+ return
+
+ # @property
+ # def num_uncond_att_layers(self):
+ # return 0
+
+ @abc.abstractmethod
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ raise NotImplementedError
+
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
+ if self.cur_att_layer >= self.num_uncond_att_layers:
+ self.forward(attn, is_cross, place_in_unet)
+ self.cur_att_layer += 1
+ if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
+ self.cur_att_layer = 0
+ self.cur_step += 1
+ self.between_steps()
+
+ def reset(self):
+ self.cur_step = 0
+ self.cur_att_layer = 0
+
+ def __init__(self):
+ self.cur_step = 0
+ self.num_att_layers = -1
+ self.cur_att_layer = 0
+
+
+class AttentionStore(AttentionControl):
+ @staticmethod
+ def get_empty_store():
+ return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}
+
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+ if attn.shape[1] <= 32**2: # avoid memory overhead
+ self.step_store[key].append(attn)
+ return attn
+
+ def between_steps(self):
+ self.attention_store = self.step_store
+ if self.save_global_store:
+ with torch.no_grad():
+ if len(self.global_store) == 0:
+ self.global_store = self.step_store
+ else:
+ for key in self.global_store:
+ for i in range(len(self.global_store[key])):
+ self.global_store[key][i] += self.step_store[key][i].detach()
+ self.step_store = self.get_empty_store()
+ self.step_store = self.get_empty_store()
+
+ def get_average_attention(self):
+ average_attention = self.attention_store
+ return average_attention
+
+ def get_average_global_attention(self):
+ average_attention = {
+ key: [item / self.cur_step for item in self.global_store[key]] for key in self.attention_store
+ }
+ return average_attention
+
+ def reset(self):
+ super(AttentionStore, self).reset()
+ self.step_store = self.get_empty_store()
+ self.attention_store = {}
+ self.global_store = {}
+
+ def __init__(self, save_global_store=False):
+ """
+        Initialize an empty AttentionStore.
+        :param save_global_store: whether to additionally accumulate attention maps across diffusion steps
+ """
+ super(AttentionStore, self).__init__()
+ self.save_global_store = save_global_store
+ self.step_store = self.get_empty_store()
+ self.attention_store = {}
+ self.global_store = {}
+ self.curr_step_index = 0
+ self.num_uncond_att_layers = 0
+
+
+def aggregate_attention(
+ attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int
+) -> torch.Tensor:
+ """Aggregates the attention across the different layers and heads at the specified resolution."""
+ out = []
+ attention_maps = attention_store.get_average_attention()
+
+ num_pixels = res**2
+ for location in from_where:
+ for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
+ if item.shape[1] == num_pixels:
+ cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select]
+ out.append(cross_maps)
+ out = torch.cat(out, dim=0)
+ out = out.sum(0) / out.shape[0]
+ return out
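+
+# Illustrative note: for cross-attention at res=16 this returns a tensor of shape
+# (16, 16, num_text_tokens), i.e. one spatial attention map per prompt token, averaged over all
+# stored layers/heads at that resolution.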
+
+
+def register_attention_control(model, controller):
+ attn_procs = {}
+ cross_att_count = 0
+ for name in model.unet.attn_processors.keys():
+ # cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ # hidden_size = model.unet.config.block_out_channels[-1]
+ place_in_unet = "mid"
+ elif name.startswith("up_blocks"):
+ # block_id = int(name[len("up_blocks.")])
+ # hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
+ place_in_unet = "up"
+ elif name.startswith("down_blocks"):
+ # block_id = int(name[len("down_blocks.")])
+ # hidden_size = model.unet.config.block_out_channels[block_id]
+ place_in_unet = "down"
+ else:
+ continue
+
+ cross_att_count += 1
+ attn_procs[name] = AttendExciteCrossAttnProcessor(attnstore=controller, place_in_unet=place_in_unet)
+ model.unet.set_attn_processor(attn_procs)
+ controller.num_att_layers = cross_att_count
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
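+
+# Illustrative note: with guidance_rescale = g the returned prediction is
+#   g * noise_cfg * (std(noise_pred_text) / std(noise_cfg)) + (1 - g) * noise_cfg,
+# so g = 0 leaves the classifier-free-guidance result untouched and g = 1 fully rescales it to the
+# per-sample standard deviation of the text-conditioned prediction.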
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionBoxDiffPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion with BoxDiff.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to `True`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+        # concatenate for backwards compatibility; encode_prompt returns (text_inputs, prompt_embeds, negative_prompt_embeds)
+        prompt_embeds = torch.cat([prompt_embeds_tuple[2], prompt_embeds_tuple[1]])
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+                # Access the `hidden_states` first, which contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
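+        # Note: unlike the stock diffusers pipeline, this variant also returns the raw tokenizer output so that
+        # `__call__` can map BoxDiff phrases to token indices. A minimal usage sketch (illustrative only, assuming a
+        # constructed `pipe` and a `device`):
+        #
+        #   text_inputs, prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
+        #       "a cat on a bench", device, num_images_per_prompt=1, do_classifier_free_guidance=True
+        #   )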
+ return text_inputs, prompt_embeds, negative_prompt_embeds
+
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ boxdiff_phrases,
+ boxdiff_boxes,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if boxdiff_phrases is not None or boxdiff_boxes is not None:
+ if not (boxdiff_phrases is not None and boxdiff_boxes is not None):
+ raise ValueError("Either both `boxdiff_phrases` and `boxdiff_boxes` must be passed or none of them.")
+
+ if not isinstance(boxdiff_phrases, list) or not isinstance(boxdiff_boxes, list):
+ raise ValueError("`boxdiff_phrases` and `boxdiff_boxes` must be lists.")
+
+ if len(boxdiff_phrases) != len(boxdiff_boxes):
+ raise ValueError(
+ "`boxdiff_phrases` and `boxdiff_boxes` must have the same length,"
+ f" got: `boxdiff_phrases` {len(boxdiff_phrases)} != `boxdiff_boxes`"
+ f" {len(boxdiff_boxes)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
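+        # Worked example (illustrative): with the standard SD setup (4 latent channels, vae_scale_factor of 8), a
+        # 512x512 request yields latents of shape (batch_size, 4, 64, 64).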
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+
+ Args:
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+ """
+ self.fusing_unet = False
+ self.fusing_vae = False
+
+ if unet:
+ self.fusing_unet = True
+ self.unet.fuse_qkv_projections()
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+ if vae:
+ if not isinstance(self.vae, AutoencoderKL):
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+ self.fusing_vae = True
+ self.vae.fuse_qkv_projections()
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """Disable QKV projection fusion if enabled.
+
+        This API is 🧪 experimental.
+
+ Args:
+            unet (`bool`, defaults to `True`): To unfuse the QKV projections on the UNet.
+            vae (`bool`, defaults to `True`): To unfuse the QKV projections on the VAE.
+
+ """
+ if unet:
+ if not self.fusing_unet:
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.unet.unfuse_qkv_projections()
+ self.fusing_unet = False
+
+ if vae:
+ if not self.fusing_vae:
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.vae.unfuse_qkv_projections()
+ self.fusing_vae = False
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+            w (`torch.Tensor`):
+                Guidance scale values for which to generate the embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
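+        # A minimal usage sketch (illustrative only), mirroring how `__call__` builds `timestep_cond` when the UNet
+        # exposes `time_cond_proj_dim`:
+        #
+        #   w = torch.tensor(guidance_scale - 1).repeat(batch_size)
+        #   timestep_cond = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim)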
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
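+    # With classifier-free guidance enabled, the denoising loop combines the two noise predictions as
+    # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond); e.g. guidance_scale = 1
+    # reduces to the text-conditioned prediction and larger values push further along the (text - uncond) direction.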
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ def _compute_max_attention_per_index(
+ self,
+ attention_maps: torch.Tensor,
+ indices_to_alter: List[int],
+ smooth_attentions: bool = False,
+ sigma: float = 0.5,
+ kernel_size: int = 3,
+ normalize_eot: bool = False,
+        bboxes: Optional[List[List[float]]] = None,
+ L: int = 1,
+ P: float = 0.2,
+ ) -> List[torch.Tensor]:
+ """Computes the maximum attention value for each of the tokens we wish to alter."""
+ last_idx = -1
+ if normalize_eot:
+ prompt = self.prompt
+ if isinstance(self.prompt, list):
+ prompt = self.prompt[0]
+ last_idx = len(self.tokenizer(prompt)["input_ids"]) - 1
+ attention_for_text = attention_maps[:, :, 1:last_idx]
+ attention_for_text *= 100
+ attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)
+
+ # Shift indices since we removed the first token "1:last_idx"
+ indices_to_alter = [index - 1 for index in indices_to_alter]
+
+ # Extract the maximum values
+ max_indices_list_fg = []
+ max_indices_list_bg = []
+ dist_x = []
+ dist_y = []
+
+ cnt = 0
+ for i in indices_to_alter:
+ image = attention_for_text[:, :, i]
+
+ # TODO
+ # box = [max(round(b / (512 / image.shape[0])), 0) for b in bboxes[cnt]]
+ # x1, y1, x2, y2 = box
+ H, W = image.shape
+ x1 = min(max(round(bboxes[cnt][0] * W), 0), W)
+ y1 = min(max(round(bboxes[cnt][1] * H), 0), H)
+ x2 = min(max(round(bboxes[cnt][2] * W), 0), W)
+ y2 = min(max(round(bboxes[cnt][3] * H), 0), H)
+ box = [x1, y1, x2, y2]
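+            # Worked example (illustrative): a normalized box [0.25, 0.25, 0.75, 0.75] on a 16x16 attention map maps
+            # to pixel coordinates [4, 4, 12, 12], i.e. an 8x8 object mask.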
+ cnt += 1
+
+ # coordinates to masks
+ obj_mask = torch.zeros_like(image)
+ ones_mask = torch.ones([y2 - y1, x2 - x1], dtype=obj_mask.dtype).to(obj_mask.device)
+ obj_mask[y1:y2, x1:x2] = ones_mask
+ bg_mask = 1 - obj_mask
+
+ if smooth_attentions:
+ smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(image.device)
+ input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect")
+ image = smoothing(input).squeeze(0).squeeze(0)
+
+ # Inner-Box constraint
+ k = (obj_mask.sum() * P).long()
+ max_indices_list_fg.append((image * obj_mask).reshape(-1).topk(k)[0].mean())
+
+ # Outer-Box constraint
+ k = (bg_mask.sum() * P).long()
+ max_indices_list_bg.append((image * bg_mask).reshape(-1).topk(k)[0].mean())
+
+ # Corner Constraint
+ gt_proj_x = torch.max(obj_mask, dim=0)[0]
+ gt_proj_y = torch.max(obj_mask, dim=1)[0]
+ corner_mask_x = torch.zeros_like(gt_proj_x)
+ corner_mask_y = torch.zeros_like(gt_proj_y)
+
+ # create gt according to the number config.L
+ N = gt_proj_x.shape[0]
+ corner_mask_x[max(box[0] - L, 0) : min(box[0] + L + 1, N)] = 1.0
+ corner_mask_x[max(box[2] - L, 0) : min(box[2] + L + 1, N)] = 1.0
+ corner_mask_y[max(box[1] - L, 0) : min(box[1] + L + 1, N)] = 1.0
+ corner_mask_y[max(box[3] - L, 0) : min(box[3] + L + 1, N)] = 1.0
+ dist_x.append((F.l1_loss(image.max(dim=0)[0], gt_proj_x, reduction="none") * corner_mask_x).mean())
+ dist_y.append((F.l1_loss(image.max(dim=1)[0], gt_proj_y, reduction="none") * corner_mask_y).mean())
+
+ return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y
+
+ def _aggregate_and_get_max_attention_per_token(
+ self,
+ attention_store: AttentionStore,
+ indices_to_alter: List[int],
+ attention_res: int = 16,
+ smooth_attentions: bool = False,
+ sigma: float = 0.5,
+ kernel_size: int = 3,
+ normalize_eot: bool = False,
+        bboxes: Optional[List[List[float]]] = None,
+ L: int = 1,
+ P: float = 0.2,
+ ):
+ """Aggregates the attention for each token and computes the max activation value for each token to alter."""
+ attention_maps = aggregate_attention(
+ attention_store=attention_store,
+ res=attention_res,
+ from_where=("up", "down", "mid"),
+ is_cross=True,
+ select=0,
+ )
+ max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._compute_max_attention_per_index(
+ attention_maps=attention_maps,
+ indices_to_alter=indices_to_alter,
+ smooth_attentions=smooth_attentions,
+ sigma=sigma,
+ kernel_size=kernel_size,
+ normalize_eot=normalize_eot,
+ bboxes=bboxes,
+ L=L,
+ P=P,
+ )
+ return max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y
+
+ @staticmethod
+ def _compute_loss(
+ max_attention_per_index_fg: List[torch.Tensor],
+ max_attention_per_index_bg: List[torch.Tensor],
+ dist_x: List[torch.Tensor],
+ dist_y: List[torch.Tensor],
+ return_losses: bool = False,
+ ) -> torch.Tensor:
+ """Computes the attend-and-excite loss using the maximum attention value for each token."""
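+        # Worked example (illustrative): if the strongest attention a token receives inside its box is 0.3, its
+        # foreground term is 1.0 - 0.3 = 0.7; attention mass outside the box (background terms) and the corner
+        # alignment distances dist_x / dist_y are added on top.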
+ losses_fg = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index_fg]
+ losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg]
+ loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y)
+ if return_losses:
+ return max(losses_fg), losses_fg
+ else:
+ return max(losses_fg), loss
+
+ @staticmethod
+ def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor:
+ """Update the latent according to the computed loss."""
+ grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0]
+ latents = latents - step_size * grad_cond
+ return latents
+
+ def _perform_iterative_refinement_step(
+ self,
+ latents: torch.Tensor,
+ indices_to_alter: List[int],
+ loss_fg: torch.Tensor,
+ threshold: float,
+ text_embeddings: torch.Tensor,
+ text_input,
+ attention_store: AttentionStore,
+ step_size: float,
+ t: int,
+ attention_res: int = 16,
+ smooth_attentions: bool = True,
+ sigma: float = 0.5,
+ kernel_size: int = 3,
+ max_refinement_steps: int = 20,
+ normalize_eot: bool = False,
+        bboxes: Optional[List[List[float]]] = None,
+ L: int = 1,
+ P: float = 0.2,
+ ):
+ """
+ Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent
+ code according to our loss objective until the given threshold is reached for all tokens.
+ """
+ iteration = 0
+ target_loss = max(0, 1.0 - threshold)
+
+ while loss_fg > target_loss:
+ iteration += 1
+
+ latents = latents.clone().detach().requires_grad_(True)
+ # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
+ self.unet.zero_grad()
+
+ # Get max activation value for each subject token
+ (
+ max_attention_per_index_fg,
+ max_attention_per_index_bg,
+ dist_x,
+ dist_y,
+ ) = self._aggregate_and_get_max_attention_per_token(
+ attention_store=attention_store,
+ indices_to_alter=indices_to_alter,
+ attention_res=attention_res,
+ smooth_attentions=smooth_attentions,
+ sigma=sigma,
+ kernel_size=kernel_size,
+ normalize_eot=normalize_eot,
+ bboxes=bboxes,
+ L=L,
+ P=P,
+ )
+
+ loss_fg, losses_fg = self._compute_loss(
+ max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True
+ )
+
+ if loss_fg != 0:
+ latents = self._update_latent(latents, loss_fg, step_size)
+
+ # with torch.no_grad():
+ # noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample
+ # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
+
+ # try:
+ # low_token = np.argmax([l.item() if not isinstance(l, int) else l for l in losses_fg])
+ # except Exception as e:
+ # print(e) # catch edge case :)
+ # low_token = np.argmax(losses_fg)
+
+ # low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]])
+ # print(f'\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index_fg[low_token]}')
+
+ if iteration >= max_refinement_steps:
+ # print(f'\t Exceeded max number of iterations ({max_refinement_steps})! '
+ # f'Finished with a max attention of {max_attention_per_index_fg[low_token]}')
+ break
+
+        # Run one more time but don't compute gradients and update the latents.
+        # We just need to compute the new loss - the gradient update happens back in the denoising loop.
+ latents = latents.clone().detach().requires_grad_(True)
+ # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
+ self.unet.zero_grad()
+
+ # Get max activation value for each subject token
+ (
+ max_attention_per_index_fg,
+ max_attention_per_index_bg,
+ dist_x,
+ dist_y,
+ ) = self._aggregate_and_get_max_attention_per_token(
+ attention_store=attention_store,
+ indices_to_alter=indices_to_alter,
+ attention_res=attention_res,
+ smooth_attentions=smooth_attentions,
+ sigma=sigma,
+ kernel_size=kernel_size,
+ normalize_eot=normalize_eot,
+ bboxes=bboxes,
+ L=L,
+ P=P,
+ )
+ loss_fg, losses_fg = self._compute_loss(
+ max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True
+ )
+ # print(f"\t Finished with loss of: {loss_fg}")
+ return loss_fg, latents, max_attention_per_index_fg
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ boxdiff_phrases: List[str] = None,
+ boxdiff_boxes: List[List[float]] = None, # TODO
+ boxdiff_kwargs: Optional[Dict[str, Any]] = {
+ "attention_res": 16,
+ "P": 0.2,
+ "L": 1,
+ "max_iter_to_alter": 25,
+ "loss_thresholds": {0: 0.05, 10: 0.5, 20: 0.8},
+ "scale_factor": 20,
+ "scale_range": (1.0, 0.5),
+ "smooth_attentions": True,
+ "sigma": 0.5,
+ "kernel_size": 3,
+ "refine": False,
+ "normalize_eot": True,
+ },
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+
+            boxdiff_phrases (`List[str]`, *optional*):
+                Phrases from `prompt` whose cross-attention maps should be constrained to the corresponding boxes
+                in `boxdiff_boxes`. Phrases that do not match a single token are skipped.
+            boxdiff_boxes (`List[List[float]]`, *optional*):
+                One normalized box `[x1, y1, x2, y2]` with values in `[0, 1]` for each phrase in `boxdiff_phrases`.
+            boxdiff_kwargs (`Dict[str, Any]`, *optional*):
+                Hyperparameters of the BoxDiff constraints, e.g. `attention_res` (resolution of the aggregated
+                attention maps), `P` (fraction of the top attention values used for the inner/outer box
+                constraints), `L` (number of pixels around each box corner used for the corner constraint), and
+                `scale_factor`, `scale_range`, `loss_thresholds`, `max_iter_to_alter` controlling the latent
+                update schedule.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ # -1. Register attention control (for BoxDiff)
+ attention_store = AttentionStore()
+ register_attention_control(self, attention_store)
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ # to deal with lora scaling and other possible forward hooks
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ boxdiff_phrases,
+ boxdiff_boxes,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+ self.prompt = prompt
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ text_inputs, prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None:
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+ # 6.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 6.3 Prepare BoxDiff inputs
+ # a) Indices to alter
+ input_ids = self.tokenizer(prompt)["input_ids"]
+ decoded = [self.tokenizer.decode([t]) for t in input_ids]
+ indices_to_alter = []
+ bboxes = []
+ for phrase, box in zip(boxdiff_phrases, boxdiff_boxes):
+            # a phrase may not correspond to exactly one token; skip phrases that are not found as a single token
+ if phrase not in decoded:
+ continue
+ indices_to_alter.append(decoded.index(phrase))
+ bboxes.append(box)
+
+ # b) A bunch of hyperparameters
+ attention_res = boxdiff_kwargs.get("attention_res", 16)
+ smooth_attentions = boxdiff_kwargs.get("smooth_attentions", True)
+ sigma = boxdiff_kwargs.get("sigma", 0.5)
+ kernel_size = boxdiff_kwargs.get("kernel_size", 3)
+ L = boxdiff_kwargs.get("L", 1)
+ P = boxdiff_kwargs.get("P", 0.2)
+ thresholds = boxdiff_kwargs.get("loss_thresholds", {0: 0.05, 10: 0.5, 20: 0.8})
+ max_iter_to_alter = boxdiff_kwargs.get("max_iter_to_alter", len(self.scheduler.timesteps) + 1)
+ scale_factor = boxdiff_kwargs.get("scale_factor", 20)
+ refine = boxdiff_kwargs.get("refine", False)
+ normalize_eot = boxdiff_kwargs.get("normalize_eot", True)
+
+ scale_range = boxdiff_kwargs.get("scale_range", (1.0, 0.5))
+ scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps))
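+        # Worked example (illustrative): with the defaults scale_factor=20 and scale_range=(1.0, 0.5) over 50 steps,
+        # the latent update step size scale_factor * sqrt(scale_range[i]) decays from 20.0 at the first step to
+        # about 14.1 (20 * sqrt(0.5)) at the last step.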
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # BoxDiff optimization
+ with torch.enable_grad():
+ latents = latents.clone().detach().requires_grad_(True)
+
+ # Forward pass of denoising with text conditioning
+ noise_pred_text = self.unet(
+ latents,
+ t,
+ encoder_hidden_states=prompt_embeds[1].unsqueeze(0),
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+ self.unet.zero_grad()
+
+ # Get max activation value for each subject token
+ (
+ max_attention_per_index_fg,
+ max_attention_per_index_bg,
+ dist_x,
+ dist_y,
+ ) = self._aggregate_and_get_max_attention_per_token(
+ attention_store=attention_store,
+ indices_to_alter=indices_to_alter,
+ attention_res=attention_res,
+ smooth_attentions=smooth_attentions,
+ sigma=sigma,
+ kernel_size=kernel_size,
+ normalize_eot=normalize_eot,
+ bboxes=bboxes,
+ L=L,
+ P=P,
+ )
+
+ loss_fg, loss = self._compute_loss(
+ max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y
+ )
+
+ # Refinement from attend-and-excite (not necessary)
+ if refine and i in thresholds.keys() and loss_fg > 1.0 - thresholds[i]:
+ del noise_pred_text
+ torch.cuda.empty_cache()
+ loss_fg, latents, max_attention_per_index_fg = self._perform_iterative_refinement_step(
+ latents=latents,
+ indices_to_alter=indices_to_alter,
+ loss_fg=loss_fg,
+ threshold=thresholds[i],
+ text_embeddings=prompt_embeds,
+ text_input=text_inputs,
+ attention_store=attention_store,
+ step_size=scale_factor * np.sqrt(scale_range[i]),
+ t=t,
+ attention_res=attention_res,
+ smooth_attentions=smooth_attentions,
+ sigma=sigma,
+ kernel_size=kernel_size,
+ normalize_eot=normalize_eot,
+ bboxes=bboxes,
+ L=L,
+ P=P,
+ )
+
+ # Perform gradient update
+ if i < max_iter_to_alter:
+ _, loss = self._compute_loss(
+ max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y
+ )
+ if loss != 0:
+ latents = self._update_latent(
+ latents=latents, loss=loss, step_size=scale_factor * np.sqrt(scale_range[i])
+ )
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_pag.py b/diffusers/examples/community/pipeline_stable_diffusion_pag.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea2c9735747e51c2f415e18e51aaefff2712ca3
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_pag.py
@@ -0,0 +1,1476 @@
+# Implementation of StableDiffusionPipeline with PAG
+# https://ku-cvlab.github.io/Perturbed-Attention-Guidance
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import Attention, AttnProcessor2_0, FusedAttnProcessor2_0
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPipeline
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+class PAGIdentitySelfAttnProcessor:
+ r"""
+    Self-attention processor for Perturbed-Attention Guidance (PAG) without classifier-free guidance: the input batch
+    stacks the original and the perturbed path, and the perturbed path replaces self-attention with the identity
+    (its output is just the value projection). Uses PyTorch 2.0 scaled dot-product attention for the original path.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("PAGIdentitySelfAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
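+        # Illustrative shape note: the caller is expected to stack the original and the perturbed path along the
+        # batch dimension, e.g. hidden_states of shape (2 * B, seq_len, dim); chunk(2) splits it into two
+        # (B, seq_len, dim) halves that are concatenated again before returning.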
+
+ # original path
+ batch_size, sequence_length, _ = hidden_states_org.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # perturbed path (identity attention)
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+ value = attn.to_v(hidden_states_ptb)
+
+ # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
+ hidden_states_ptb = value
+
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class PAGCFGIdentitySelfAttnProcessor:
+ r"""
+    Self-attention processor for Perturbed-Attention Guidance (PAG) combined with classifier-free guidance: the input
+    batch stacks the unconditional, the conditional and the perturbed path (in that order), and the perturbed path
+    replaces self-attention with the identity (its output is just the value projection). Uses PyTorch 2.0 scaled
+    dot-product attention for the other paths.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("PAGCFGIdentitySelfAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+ # original path
+ batch_size, sequence_length, _ = hidden_states_org.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # perturbed path (identity attention)
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+ value = attn.to_v(hidden_states_ptb)
+ hidden_states_ptb = value
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
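+    # Illustrative usage (a sketch; the 0.7 value and the 4-D shapes are assumptions): with
+    # `noise_cfg` and `noise_pred_text` of shape (B, C, H, W),
+    #     noise_cfg = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
+    # matches the std of the guided prediction to that of the text-only prediction and then
+    # blends 70% of the rescaled result with 30% of the original CFG output.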
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
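+    # Example (a sketch; `pipe` and the custom schedule below are assumptions):
+    #     timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps=30, device="cuda")
+    # or, for schedulers whose `set_timesteps` accepts a custom schedule:
+    #     timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, timesteps=[999, 749, 499, 249, 0], device="cuda")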
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionPAGPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly, as leaving `steps_offset` unchanged might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to `True`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
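+        # Note: `output_hidden_states=True` is requested for IP-Adapter variants whose projection
+        # layer is not a plain `ImageProjection` (see `prepare_ip_adapter_image_embeds` below);
+        # those variants consume the penultimate hidden states of the image encoder rather than
+        # the pooled `image_embeds`.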
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ image_embeds = ip_adapter_image_embeds
+ return image_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot pass both `ip_adapter_image` and `ip_adapter_image_embeds`."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+ The suffixes after the scaling factors represent the stages where they are being applied.
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ This API is 🧪 experimental.
+
+ Args:
+            unet (`bool`, defaults to `True`): Whether to apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): Whether to apply fusion on the VAE.
+ """
+ self.fusing_unet = False
+ self.fusing_vae = False
+
+ if unet:
+ self.fusing_unet = True
+ self.unet.fuse_qkv_projections()
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+ if vae:
+ if not isinstance(self.vae, AutoencoderKL):
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+ self.fusing_vae = True
+ self.vae.fuse_qkv_projections()
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """Disable QKV projection fusion if enabled.
+
+ This API is 🧪 experimental.
+
+ Args:
+            unet (`bool`, defaults to `True`): Whether to undo the fusion on the UNet.
+            vae (`bool`, defaults to `True`): Whether to undo the fusion on the VAE.
+ """
+ if unet:
+ if not self.fusing_unet:
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.unet.unfuse_qkv_projections()
+ self.fusing_unet = False
+
+ if vae:
+ if not self.fusing_vae:
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.vae.unfuse_qkv_projections()
+ self.fusing_vae = False
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+ Args:
+            w (`torch.Tensor`):
+                Guidance-scale values at which to generate the embedding vectors.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+ Returns:
+            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
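+        # Illustrative usage (a sketch; the 256-dim embedding size is an assumption): for a 1-D
+        # guidance-weight tensor such as torch.tensor([6.5, 6.5]),
+        #     emb = self.get_guidance_scale_embedding(torch.tensor([6.5, 6.5]), embedding_dim=256)
+        # returns a (2, 256) sinusoidal embedding of w * 1000.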
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ def pred_z0(self, sample, model_output, timestep):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)
+
+ beta_prod_t = 1 - alpha_prod_t
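+        # For "epsilon" prediction, x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t);
+        # the "v_prediction" branch uses x0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v.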
+ if self.scheduler.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.scheduler.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.scheduler.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ # predict V
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+ " or `v_prediction`"
+ )
+
+ return pred_original_sample
+
+ def pred_x0(self, latents, noise_pred, t, generator, device, prompt_embeds, output_type):
+ pred_z0 = self.pred_z0(latents, noise_pred, t)
+ pred_x0 = self.vae.decode(pred_z0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
+ pred_x0, ____ = self.run_safety_checker(pred_x0, device, prompt_embeds.dtype)
+ do_denormalize = [True] * pred_x0.shape[0]
+ pred_x0 = self.image_processor.postprocess(pred_x0, output_type=output_type, do_denormalize=do_denormalize)
+
+ return pred_x0
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
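+        # When the UNet carries a guidance-scale embedding (`time_cond_proj_dim` is set, e.g. for
+        # guidance-distilled / LCM-style models), guidance is conditioned inside the model, so the
+        # two-pass classifier-free-guidance batch is skipped.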
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def pag_scale(self):
+ return self._pag_scale
+
+ @property
+ def do_perturbed_attention_guidance(self):
+ return self._pag_scale > 0
+
+ @property
+ def pag_adaptive_scaling(self):
+ return self._pag_adaptive_scaling
+
+ @property
+ def do_pag_adaptive_scaling(self):
+ return self._pag_adaptive_scaling > 0
+
+ @property
+ def pag_applied_layers_index(self):
+ return self._pag_applied_layers_index
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ pag_scale: float = 0.0,
+ pag_adaptive_scaling: float = 0.0,
+ pag_applied_layers_index: List[str] = ["d4"], # ['d4', 'd5', 'm0']
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ Examples:
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
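+        # Illustrative call (a minimal sketch; the checkpoint id and the PAG layer index are
+        # assumptions, not fixed by this pipeline):
+        #     pipe = StableDiffusionPAGPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        #     image = pipe("a corgi wearing a top hat", guidance_scale=7.5, pag_scale=3.0,
+        #                  pag_applied_layers_index=["m0"]).images[0]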
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ # to deal with lora scaling and other possible forward hooks
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ self._pag_scale = pag_scale
+ self._pag_adaptive_scaling = pag_adaptive_scaling
+ self._pag_applied_layers_index = pag_applied_layers_index
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+
+ # cfg
+ if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ # pag
+ elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
+ # both
+ elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
+ else None
+ )
+
+ # 6.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Denoising loop
+ if self.do_perturbed_attention_guidance:
+ down_layers = []
+ mid_layers = []
+ up_layers = []
+ for name, module in self.unet.named_modules():
+ if "attn1" in name and "to" not in name:
+ layer_type = name.split(".")[0].split("_")[0]
+ if layer_type == "down":
+ down_layers.append(module)
+ elif layer_type == "mid":
+ mid_layers.append(module)
+ elif layer_type == "up":
+ up_layers.append(module)
+ else:
+ raise ValueError(f"Invalid layer type: {layer_type}")
+
+ # change attention layer in UNet if use PAG
+ if self.do_perturbed_attention_guidance:
+ if self.do_classifier_free_guidance:
+ replace_processor = PAGCFGIdentitySelfAttnProcessor()
+ else:
+ replace_processor = PAGIdentitySelfAttnProcessor()
+
+ drop_layers = self.pag_applied_layers_index
+ for drop_layer in drop_layers:
+ try:
+ if drop_layer[0] == "d":
+ down_layers[int(drop_layer[1])].processor = replace_processor
+ elif drop_layer[0] == "m":
+ mid_layers[int(drop_layer[1])].processor = replace_processor
+ elif drop_layer[0] == "u":
+ up_layers[int(drop_layer[1])].processor = replace_processor
+ else:
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
+ except IndexError:
+ raise ValueError(
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
+ )
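+            # Note: an index such as "d4" selects the self-attention ("attn1") module at 0-based
+            # position 4 among those collected from the down blocks above, in `named_modules`
+            # order; "m0" and "u0" index the mid and up blocks in the same way.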
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # cfg
+ if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
+ latent_model_input = torch.cat([latents] * 2)
+ # pag
+ elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
+ latent_model_input = torch.cat([latents] * 2)
+ # both
+ elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
+ latent_model_input = torch.cat([latents] * 3)
+ # no
+ else:
+ latent_model_input = latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+
+ # cfg
+ if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+
+ delta = noise_pred_text - noise_pred_uncond
+ noise_pred = noise_pred_uncond + self.guidance_scale * delta
+
+ # pag
+ elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
+ noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)
+
+ signal_scale = self.pag_scale
+ if self.do_pag_adaptive_scaling:
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
+ if signal_scale < 0:
+ signal_scale = 0
+
+ noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)
+
+ # both
+ elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
+ noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
+
+ signal_scale = self.pag_scale
+ if self.do_pag_adaptive_scaling:
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
+ if signal_scale < 0:
+ signal_scale = 0
+
+ noise_pred = (
+ noise_pred_text
+ + (self.guidance_scale - 1.0) * (noise_pred_text - noise_pred_uncond)
+ + signal_scale * (noise_pred_text - noise_pred_text_perturb)
+ )
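+                    # i.e. eps_hat = eps_text + (w_cfg - 1) * (eps_text - eps_uncond)
+                    #                         + w_pag * (eps_text - eps_text_perturb)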
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ # change attention layer in UNet if use PAG
+ if self.do_perturbed_attention_guidance:
+ drop_layers = self.pag_applied_layers_index
+ for drop_layer in drop_layers:
+ try:
+ if drop_layer[0] == "d":
+ down_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
+ elif drop_layer[0] == "m":
+ mid_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
+ elif drop_layer[0] == "u":
+ up_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
+ else:
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
+ except IndexError:
+ raise ValueError(
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
+ )
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/diffusers/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac651a1fe609ee474046ab7421e9184bc29ee13
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
@@ -0,0 +1,772 @@
+# Copyright 2024 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.image_processor import PipelineDepthInput, PipelineImageInput, VaeImageProcessorLDM3D
+from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d import LDM3DPipelineOutput
+from diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> from diffusers import StableDiffusionUpscaleLDM3DPipeline
+ >>> from PIL import Image
+ >>> from io import BytesIO
+ >>> import requests
+
+ >>> pipe = StableDiffusionUpscaleLDM3DPipeline.from_pretrained("Intel/ldm3d-sr")
+ >>> pipe = pipe.to("cuda")
+ >>> rgb_path = "https://huggingface.co/Intel/ldm3d-sr/resolve/main/lemons_ldm3d_rgb.jpg"
+ >>> depth_path = "https://huggingface.co/Intel/ldm3d-sr/resolve/main/lemons_ldm3d_depth.png"
+ >>> low_res_rgb = Image.open(BytesIO(requests.get(rgb_path).content)).convert("RGB")
+ >>> low_res_depth = Image.open(BytesIO(requests.get(depth_path).content)).convert("L")
+ >>> output = pipe(
+ ... prompt="high quality high resolution uhd 4k image",
+ ... rgb=low_res_rgb,
+ ... depth=low_res_depth,
+ ... num_inference_steps=50,
+ ... target_res=[1024, 1024],
+ ... )
+ >>> rgb_image, depth_image = output.rgb, output.depth
+ >>> rgb_image[0].save("hr_ldm3d_rgb.jpg")
+ >>> depth_image[0].save("hr_ldm3d_depth.png")
+ ```
+"""
+
+
+class StableDiffusionUpscaleLDM3DPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for upscaling the RGB and depth images produced by LDM3D.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ low_res_scheduler ([`SchedulerMixin`]):
+ A scheduler used to add initial noise to the low resolution conditioning image. It must be an instance of
+ [`DDPMScheduler`].
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ low_res_scheduler: DDPMScheduler,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ watermarker: Optional[Any] = None,
+ max_noise_level: int = 350,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ low_res_scheduler=low_res_scheduler,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ watermarker=watermarker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample="bilinear")
+ # self.register_to_config(requires_safety_checker=requires_safety_checker)
+ self.register_to_config(max_noise_level=max_noise_level)
+
+ # Copied from diffusers.pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards compatibility
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
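For callers migrating off the deprecated `_encode_prompt`, here is a minimal sketch (using toy tensors rather than the real text encoder) of how the legacy concatenated layout relates to the `(prompt_embeds, negative_prompt_embeds)` tuple returned by `encode_prompt` below:

```python
import torch

# Toy stand-ins for real CLIP embeddings: [batch, seq_len, hidden].
prompt_embeds = torch.randn(2, 77, 768)
negative_prompt_embeds = torch.randn(2, 77, 768)

# `_encode_prompt` stacks the negative embeddings first, then the positive ones,
# along the batch dimension (the "concatenated tensor" from the deprecation message).
legacy = torch.cat([negative_prompt_embeds, prompt_embeds])

# The two halves can be recovered with `chunk`, matching the new tuple output.
neg, pos = legacy.chunk(2)
assert torch.equal(neg, negative_prompt_embeds) and torch.equal(pos, prompt_embeds)
```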
+ # Copied from diffusers.pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ rgb_feature_extractor_input = feature_extractor_input[0]
+ safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
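As a quick illustration of why the scheduler's `step()` signature is inspected instead of always passing `eta`: among the stock schedulers only DDIM accepts it. A small check, assuming a recent `diffusers` install:

```python
import inspect

from diffusers import DDIMScheduler, PNDMScheduler

# DDIM's step() exposes `eta` (the η from the DDIM paper); PNDM and most other
# schedulers do not, so passing it unconditionally would raise a TypeError.
print("eta" in inspect.signature(DDIMScheduler().step).parameters)  # True
print("eta" in inspect.signature(PNDMScheduler().step).parameters)  # False
```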
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ noise_level,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ target_res=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, np.ndarray)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
+ )
+
+ # verify batch size of prompt and image are same if image is a list or tensor or numpy array
+ if isinstance(image, (list, np.ndarray, torch.Tensor)):
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ else:
+ image_batch_size = image.shape[0]
+ if batch_size != image_batch_size:
+ raise ValueError(
+ f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
+ " Please make sure that passed `prompt` matches the batch size of `image`."
+ )
+
+ # check noise level
+ if noise_level > self.config.max_noise_level:
+ raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height, width)
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # def upcast_vae(self):
+ # dtype = self.vae.dtype
+ # self.vae.to(dtype=torch.float32)
+ # use_torch_2_0_or_xformers = isinstance(
+ # self.vae.decoder.mid_block.attentions[0].processor,
+ # (
+ # AttnProcessor2_0,
+ # XFormersAttnProcessor,
+ # LoRAXFormersAttnProcessor,
+ # LoRAAttnProcessor2_0,
+ # ),
+ # )
+ # # if xformers or torch_2_0 is used attention block does not need
+ # # to be in float32 which can save lots of memory
+ # if use_torch_2_0_or_xformers:
+ # self.vae.post_quant_conv.to(dtype)
+ # self.vae.decoder.conv_in.to(dtype)
+ # self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ rgb: PipelineImageInput = None,
+ depth: PipelineDepthInput = None,
+ num_inference_steps: int = 75,
+ guidance_scale: float = 9.0,
+ noise_level: int = 20,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ target_res: Optional[List[int]] = [1024, 1024],
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ rgb (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image` or tensor representing the low-resolution RGB image batch to be upscaled.
+ depth (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image` or tensor representing the corresponding low-resolution depth batch.
+ num_inference_steps (`int`, *optional*, defaults to 75):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 9.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ rgb,
+ noise_level,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # 4. Preprocess image
+ rgb, depth = self.image_processor.preprocess(rgb, depth, target_res=target_res)
+ rgb = rgb.to(dtype=prompt_embeds.dtype, device=device)
+ depth = depth.to(dtype=prompt_embeds.dtype, device=device)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Encode the low-resolution image into latent space
+ image = torch.cat([rgb, depth], dim=1)
+ latent_space_image = self.vae.encode(image).latent_dist.sample(generator)
+ latent_space_image *= self.vae.config.scaling_factor
+ noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
+ # noise_rgb = randn_tensor(rgb.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
+ # rgb = self.low_res_scheduler.add_noise(rgb, noise_rgb, noise_level)
+ # noise_depth = randn_tensor(depth.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
+ # depth = self.low_res_scheduler.add_noise(depth, noise_depth, noise_level)
+
+ batch_multiplier = 2 if do_classifier_free_guidance else 1
+ latent_space_image = torch.cat([latent_space_image] * batch_multiplier * num_images_per_prompt)
+ noise_level = torch.cat([noise_level] * latent_space_image.shape[0])
+
+ # 7. Prepare latent variables
+ height, width = latent_space_image.shape[2:]
+ num_channels_latents = self.vae.config.latent_channels
+
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Check that sizes of image and latents match
+ num_channels_image = latent_space_image.shape[1]
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_image`: {num_channels_image} "
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ " `pipeline.unet` or your `image` input."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = torch.cat([latent_model_input, latent_space_image], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=noise_level,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # 11. Apply watermark
+ if output_type == "pil" and self.watermarker is not None:
+ rgb = self.watermarker.apply_watermark(rgb)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return ((rgb, depth), has_nsfw_concept)
+
+ return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept)
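Because this file is added under `diffusers/examples/community`, it is meant to be loaded through the `custom_pipeline` mechanism rather than imported directly. A minimal loading sketch, assuming the module name below resolves against this checkout's community pipelines (the `Intel/ldm3d-sr` checkpoint id is taken from the example docstring above):

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical loading path: `custom_pipeline` is assumed to resolve to the
# pipeline_stable_diffusion_upscale_ldm3d module added above; adjust the name
# or pass a local path depending on how community pipelines are resolved here.
pipe = DiffusionPipeline.from_pretrained(
    "Intel/ldm3d-sr",
    custom_pipeline="pipeline_stable_diffusion_upscale_ldm3d",
    torch_dtype=torch.float16,
).to("cuda")
```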
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae495979f3664cc5fceff0c07df2b18978db3ecc
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
@@ -0,0 +1,1401 @@
+# Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ USE_PEFT_BACKEND,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import ControlNetModel, T2IAdapter, StableDiffusionXLAdapterPipeline, DDPMScheduler
+ >>> from diffusers.utils import load_image
+ >>> from controlnet_aux.midas import MidasDetector
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> image = load_image(img_url).resize((1024, 1024))
+ >>> mask_image = load_image(mask_url).resize((1024, 1024))
+
+ >>> midas_depth = MidasDetector.from_pretrained(
+ ... "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+ ... ).to("cuda")
+
+ >>> depth_image = midas_depth(
+ ... image, detect_resolution=512, image_resolution=1024
+ ... )
+
+ >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+
+ >>> adapter = T2IAdapter.from_pretrained(
+ ... "Adapter/t2iadapter",
+ ... subfolder="sketch_sdxl_1.0",
+ ... torch_dtype=torch.float16,
+ ... adapter_type="full_adapter_xl",
+ ... )
+
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "diffusers/controlnet-depth-sdxl-1.0",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True
+ ... ).to("cuda")
+
+ >>> scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
+
+ >>> pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+ ... model_id,
+ ... adapter=adapter,
+ ... controlnet=controlnet,
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... scheduler=scheduler
+ ... ).to("cuda")
+
+ >>> strength = 0.5
+
+ >>> generator = torch.manual_seed(42)
+ >>> sketch_image_out = pipe(
+ ... prompt="a photo of a tiger sitting on a park bench",
+ ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
+ ... adapter_image=depth_image,
+ ... control_image=mask_image,
+ ... adapter_conditioning_scale=strength,
+ ... controlnet_conditioning_scale=strength,
+ ... generator=generator,
+ ... guidance_scale=7.5,
+ ... ).images[0]
+ ```
+"""
+
+
+def _preprocess_adapter_image(image, height, width):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
+ image = [
+ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
+ ] # expand [h, w] or [h, w, c] to [b, h, w, c]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ if image[0].ndim == 3:
+ image = torch.stack(image, dim=0)
+ elif image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ else:
+ raise ValueError(
+ f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
+ )
+ return image
+
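A small sketch of the shapes `_preprocess_adapter_image` produces for the common input types (toy data only; assumes the function defined above is in scope):

```python
import numpy as np
import PIL.Image
import torch

# A single grayscale PIL image becomes a [1, 1, H, W] float tensor in [0, 1];
# an RGB image would become [1, 3, H, W].
gray = PIL.Image.fromarray(np.zeros((32, 48), dtype=np.uint8))
print(_preprocess_adapter_image(gray, height=64, width=64).shape)  # torch.Size([1, 1, 64, 64])

# Tensors are returned unchanged, so an already-batched [B, C, H, W] input passes through.
batched = torch.rand(2, 3, 64, 64)
print(_preprocess_adapter_image(batched, 64, 64).shape)  # torch.Size([2, 3, 64, 64])
```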
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
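A minimal sketch of where `rescale_noise_cfg` sits relative to the usual classifier-free guidance combination (toy tensors; assumes the function above is in scope; `guidance_rescale=0.0` would leave the prediction unchanged):

```python
import torch

noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)
guidance_scale, guidance_rescale = 7.5, 0.7

# Standard classifier-free guidance...
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# ...followed by the std-matching rescale from the paper: 0.0 keeps the CFG result,
# 1.0 fully matches the text-conditioned prediction's standard deviation.
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])
```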
+
+class StableDiffusionXLControlNetAdapterPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL augmented with T2I-Adapter
+ (https://arxiv.org/abs/2302.08453) and ControlNet conditioning.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
+ Provides additional conditioning to the UNet during the denoising process. If you pass multiple Adapters as a
+ list, the outputs from each Adapter are added together to create one combined additional conditioning.
+ adapter_weights (`List[float]`, *optional*, defaults to None):
+ List of floats representing the weight by which each adapter's output is multiplied before the outputs are
+ added together.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+ _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
+ controlnet: Union[ControlNetModel, MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ adapter=adapter,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.default_sample_size = self.unet.config.sample_size
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def check_conditions(
+ self,
+ prompt,
+ prompt_embeds,
+ adapter_image,
+ control_image,
+ adapter_conditioning_scale,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ ):
+ # controlnet checks
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Check controlnet `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(control_image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(control_image, list):
+ raise TypeError("For multiple controlnets: `control_image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in control_image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(control_image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(control_image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for image_ in control_image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ # adapter checks
+ if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):
+ self.check_image(adapter_image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)
+ ):
+ if not isinstance(adapter_image, list):
+ raise TypeError("For multiple adapters: `adapter_image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in adapter_image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(adapter_image) != len(self.adapter.adapters):
+ raise ValueError(
+ f"For multiple adapters: `image` must have the same length as the number of adapters, but got {len(adapter_image)} images and {len(self.adapters.nets)} Adapters."
+ )
+
+ for image_ in adapter_image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `adapter_conditioning_scale`
+ if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):
+ if not isinstance(adapter_conditioning_scale, float):
+ raise TypeError("For single adapter: `adapter_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)
+ ):
+ if isinstance(adapter_conditioning_scale, list):
+ if any(isinstance(i, list) for i in adapter_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif isinstance(adapter_conditioning_scale, list) and len(adapter_conditioning_scale) != len(
+ self.adapter.adapters
+ ):
+ raise ValueError(
+ "For multiple adapters: When `adapter_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of adapters"
+ )
+ else:
+ assert False
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
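+ # SDXL micro-conditioning: the six integers from original_size, crops_coords_top_left
+ # and target_size are embedded and combined with the pooled text embedding by the
+ # UNet's `add_embedding`.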
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (AttnProcessor2_0, XFormersAttnProcessor),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width
+ def _default_height_width(self, height, width, image):
+ # NOTE: It is possible that a list of images have different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[-2]
+
+ # round down to nearest multiple of `self.adapter.downscale_factor`
+ height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[-1]
+
+ # round down to nearest multiple of `self.adapter.downscale_factor`
+ width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor
+
+ return height, width
+
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ adapter_image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ adapter_conditioning_scale: Union[float, List[float]] = 1.0,
+ adapter_conditioning_factor: float = 1.0,
+ clip_skip: Optional[int] = None,
+ controlnet_conditioning_scale=1.0,
+ guess_mode: bool = False,
+ control_guidance_start: float = 0.0,
+ control_guidance_end: float = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ adapter_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):
+ The Adapter input condition. The Adapter uses this input condition to generate guidance for the UNet. If the
+ type is specified as `torch.Tensor`, it is passed to the Adapter as is. `PIL.Image.Image` can also be
+ accepted as an image. The control image is automatically resized to fit the output image.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+ accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
+ input to a single ControlNet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`]
+ instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be the same
+ as the `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the
+ residual in the original UNet. If multiple ControlNets are specified in `init`, you can set the
+ corresponding scale as a list.
+ adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
+ residual in the original unet. If multiple adapters are specified in init, you can set the
+ corresponding scale as a list.
+ adapter_conditioning_factor (`float`, *optional*, defaults to 1.0):
+ The fraction of timesteps for which adapter should be applied. If `adapter_conditioning_factor` is
+ `0.0`, adapter is not applied at all. If `adapter_conditioning_factor` is `1.0`, adapter is applied for
+ all timesteps. If `adapter_conditioning_factor` is `0.5`, adapter is applied for half of the timesteps.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+ adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter
+
+ # 0. Default height and width to unet
+
+ height, width = self._default_height_width(height, width, adapter_image)
+ device = self._execution_device
+
+ if isinstance(adapter, MultiAdapter):
+ adapter_input = []
+
+ for one_image in adapter_image:
+ one_image = _preprocess_adapter_image(one_image, height, width)
+ one_image = one_image.to(device=device, dtype=adapter.dtype)
+ adapter_input.append(one_image)
+ else:
+ adapter_input = _preprocess_adapter_image(adapter_image, height, width)
+ adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 0.1 align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+ if isinstance(adapter, MultiAdapter) and isinstance(adapter_conditioning_scale, float):
+ adapter_conditioning_scale = [adapter_conditioning_scale] * len(adapter.adapters)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+
+ self.check_conditions(
+ prompt,
+ prompt_embeds,
+ adapter_image,
+ control_image,
+ adapter_conditioning_scale,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ clip_skip=clip_skip,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings & adapter features
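+ # The T2I-Adapter returns one residual feature map per UNet down block. Below these
+ # residuals are scaled by `adapter_conditioning_scale`, repeated per generated image,
+ # duplicated for classifier-free guidance, and later injected into the UNet via
+ # `down_intrablock_additional_residuals`.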
+ if isinstance(adapter, MultiAdapter):
+ adapter_state = adapter(adapter_input, adapter_conditioning_scale)
+ else:
+ adapter_state = adapter(adapter_input)
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v * adapter_conditioning_scale
+ if num_images_per_prompt > 1:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
+ if do_classifier_free_guidance:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = torch.cat([v] * 2, dim=0)
+
+ # 7.2 Prepare control images
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ raise ValueError(f"{controlnet.__class__} is not supported.")
+
+ # 7.3 Create tensor stating which controlnets to keep
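+ # For each timestep i, the keep flag is 1.0 while the step lies inside the
+ # [control_guidance_start, control_guidance_end] window and 0.0 outside it;
+ # it multiplies the ControlNet conditioning scale in the denoising loop below.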
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ if isinstance(self.controlnet, MultiControlNetModel):
+ controlnet_keep.append(keeps)
+ else:
+ controlnet_keep.append(keeps[0])
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
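+ # Adapter residuals are only injected for the first `adapter_conditioning_factor`
+ # fraction of the steps; afterwards the UNet runs without adapter guidance.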
+ if i < int(num_inference_steps * adapter_conditioning_factor):
+ down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
+ else:
+ down_intrablock_additional_residuals = None
+
+ # ----------- ControlNet
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input_controlnet = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input_controlnet = self.scheduler.scale_model_input(latent_model_input_controlnet, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input_controlnet
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals, # t2iadapter
+ down_block_additional_residuals=down_block_res_samples, # controlnet
+ mid_block_additional_residual=mid_block_res_sample, # controlnet
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..94ca71cf7b1bf61fb465c5ac134405286ff386c9
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -0,0 +1,1840 @@
+# Copyright 2024 Jake Babbidge, TencentARC and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ignore the entire file for precommit
+# type: ignore
+
+import inspect
+from collections.abc import Callable
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from transformers import (
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+)
+
+from diffusers import DiffusionPipeline
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ StableDiffusionLoraLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import (
+ AutoencoderKL,
+ ControlNetModel,
+ MultiAdapter,
+ T2IAdapter,
+ UNet2DConditionModel,
+)
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ USE_PEFT_BACKEND,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import ControlNetModel, DiffusionPipeline, T2IAdapter
+ >>> from diffusers.utils import load_image
+ >>> from PIL import Image
+ >>> from controlnet_aux.midas import MidasDetector
+
+ >>> adapter = T2IAdapter.from_pretrained(
+ ... "TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
+ ... ).to("cuda")
+
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "diffusers/controlnet-depth-sdxl-1.0",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True
+ ... ).to("cuda")
+
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True,
+ ... custom_pipeline="stable_diffusion_xl_adapter_controlnet_inpaint",
+ ... adapter=adapter,
+ ... controlnet=controlnet,
+ ... ).to("cuda")
+
+ >>> prompt = "a tiger sitting on a park bench"
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> image = load_image(img_url).resize((1024, 1024))
+ >>> mask_image = load_image(mask_url).resize((1024, 1024))
+
+ >>> midas_depth = MidasDetector.from_pretrained(
+ ... "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+ ... ).to("cuda")
+
+ >>> depth_image = midas_depth(
+ ... image, detect_resolution=512, image_resolution=1024
+ ... )
+
+ >>> strength = 0.4
+
+ >>> generator = torch.manual_seed(42)
+
+ >>> result_image = pipe(
+ ... image=image,
+ ... mask_image=mask_image,
+ ... adapter_image=depth_image,
+ ... control_image=depth_image,
+ ... controlnet_conditioning_scale=strength,
+ ... adapter_conditioning_scale=strength,
+ ... strength=0.7,
+ ... generator=generator,
+ ... prompt=prompt,
+ ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
+ ... num_inference_steps=50
+ ... ).images[0]
+ ```
+"""
+
+
+def _preprocess_adapter_image(image, height, width):
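+ # Normalizes adapter conditioning images to a [batch, channels, height, width] float tensor:
+ # PIL inputs are resized and scaled to [0, 1]; tensor inputs are stacked/concatenated (or
+ # returned) as-is.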
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
+ image = [
+ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
+ ] # expand [h, w] or [h, w, c] to [b, h, w, c]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ if image[0].ndim == 3:
+ image = torch.stack(image, dim=0)
+ elif image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ else:
+ raise ValueError(
+ f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
+ )
+ return image
+
+
+def mask_pil_to_torch(mask, height, width):
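+ # Converts a mask (PIL image, np.ndarray, or a list of either) to a [batch, 1, height, width]
+ # float tensor; PIL masks are converted to grayscale, resized, and scaled to [0, 1].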
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
+ ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
+ ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ mask = mask_pil_to_torch(mask, height, width)
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ # if image.min() < -1 or image.max() > 1:
+ # raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = mask_pil_to_torch(mask, height, width)
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ if image.shape[1] == 4:
+ # images are in latent space and thus can't be masked;
+ # set masked_image to None. We assume that the checkpoint
+ # is not an inpainting checkpoint.
+ # TODO(Yiyi) - need to clean this up later
+ masked_image = None
+ else:
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
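+# Hypothetical usage sketch (variable names are illustrative): for a 1024x1024 PIL
+# `init_image` and `mask_image`,
+#     mask, masked_image, image = prepare_mask_and_masked_image(init_image, mask_image, 1024, 1024, return_image=True)
+# yields a binary mask of shape [1, 1, 1024, 1024] and image/masked_image tensors of shape
+# [1, 3, 1024, 1024] normalized to [-1, 1], with masked_image zeroed where mask >= 0.5.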
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLControlNetAdapterInpaintPipeline(
+ DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionLoraLoaderMixin
+):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion XL, augmented with T2I-Adapter
+ (https://arxiv.org/abs/2302.08453) and ControlNet conditioning.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple Adapters as a
+ list, the outputs from each Adapter are added together to create one combined additional conditioning.
+ adapter_weights (`List[float]`, *optional*, defaults to None):
+ List of floats representing the weight by which each adapter's output will be multiplied before adding them
+ together.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ adapter: Union[T2IAdapter, MultiAdapter],
+ controlnet: Union[ControlNetModel, MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ adapter=adapter,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
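+ # For the standard SDXL VAE (four block_out_channels) this evaluates to 8, so latents
+ # have spatial size height/8 x width/8.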
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.default_sample_size = self.unet.config.sample_size
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We only ever use the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always interested only in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def check_conditions(
+ self,
+ prompt,
+ prompt_embeds,
+ adapter_image,
+ control_image,
+ adapter_conditioning_scale,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ ):
+ # controlnet checks
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Check controlnet `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(control_image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(control_image, list):
+ raise TypeError("For multiple controlnets: `control_image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in control_image):
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+ elif len(control_image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `control_image` must have the same length as the number of controlnets, but got {len(control_image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for image_ in control_image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ # adapter checks
+ if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):
+ self.check_image(adapter_image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)
+ ):
+ if not isinstance(adapter_image, list):
+ raise TypeError("For multiple adapters: `adapter_image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in adapter_image):
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+ elif len(adapter_image) != len(self.adapter.adapters):
+ raise ValueError(
+ f"For multiple adapters: `adapter_image` must have the same length as the number of adapters, but got {len(adapter_image)} images and {len(self.adapter.adapters)} Adapters."
+ )
+
+ for image_ in adapter_image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `adapter_conditioning_scale`
+ if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):
+ if not isinstance(adapter_conditioning_scale, float):
+ raise TypeError("For single adapter: `adapter_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)
+ ):
+ if isinstance(adapter_conditioning_scale, list):
+ if any(isinstance(i, list) for i in adapter_conditioning_scale):
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+ elif isinstance(adapter_conditioning_scale, list) and len(adapter_conditioning_scale) != len(
+ self.adapter.adapters
+ ):
+ raise ValueError(
+ "For multiple adapters: When `adapter_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of adapters"
+ )
+ else:
+ assert False
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1, initial latents are to be initialised as a combination of image and noise. "
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask,
+ size=(
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ ),
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ masked_image_latents = None
+ if masked_image is not None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
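+ # e.g. with num_inference_steps=50, strength=0.6 and a first-order scheduler, init_timestep=30 and
+ # t_start=20, so only the final 30 of the 50 scheduled timesteps are actually run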
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ timesteps = timesteps[-num_inference_steps:]
+ return timesteps, num_inference_steps
+
+ return timesteps, num_inference_steps - t_start
+
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+
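+ # each entry of add_time_ids is embedded with `addition_time_embed_dim` features inside the UNet, so the
+ # flattened width plus the pooled text projection has to match `add_embedding.linear_1.in_features`,
+ # which the checks below enforce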
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} are correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (AttnProcessor2_0, XFormersAttnProcessor),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width
+ def _default_height_width(self, height, width, image):
+ # NOTE: It is possible that a list of images have different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[-2]
+
+ # round down to nearest multiple of `self.adapter.downscale_factor`
+ height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[-1]
+
+ # round down to nearest multiple of `self.adapter.downscale_factor`
+ width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor
+
+ return height, width
+
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
+ mask_image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
+ adapter_image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.9999,
+ num_inference_steps: int = 50,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Optional[Tuple[int, int]] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ adapter_conditioning_scale: Optional[Union[float, List[float]]] = 1.0,
+ cond_tau: float = 1.0,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ adapter_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):
+ The Adapter input condition. The Adapter uses this input condition to generate guidance for the UNet.
+ If the type is specified as `torch.Tensor`, it is passed to the Adapter as is. `PIL.Image.Image` can
+ also be accepted as an image. The control image is automatically resized to fit the output image.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+ accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
+ input to a single ControlNet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`]
+ instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the
+ residual in the original UNet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list.
+ adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
+ residual in the original unet. If multiple adapters are specified in init, you can set the
+ corresponding scale as a list.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
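+ # unwrap torch.compile wrappers so the isinstance checks below see the underlying module classes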
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+ adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter
+ height, width = self._default_height_width(height, width, adapter_image)
+ device = self._execution_device
+
+ if isinstance(adapter, MultiAdapter):
+ adapter_input = []
+ for one_image in adapter_image:
+ one_image = _preprocess_adapter_image(one_image, height, width)
+ one_image = one_image.to(device=device, dtype=adapter.dtype)
+ adapter_input.append(one_image)
+ else:
+ adapter_input = _preprocess_adapter_image(adapter_image, height, width)
+ adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 0.1 align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+ if isinstance(adapter, MultiAdapter) and isinstance(adapter_conditioning_scale, float):
+ adapter_conditioning_scale = [adapter_conditioning_scale] * len(adapter.nets)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+
+ self.check_conditions(
+ prompt,
+ prompt_embeds,
+ adapter_image,
+ control_image,
+ adapter_conditioning_scale,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=denoising_start if denoising_value_valid(denoising_start) else None,
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
+ image, mask_image, height, width, return_image=True
+ )
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
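+ # a 4-channel UNet is a regular text-to-image UNet, so the original image latents are kept around and
+ # re-blended into the unmasked region at every step; a 9-channel (inpainting) UNet instead receives the
+ # mask and masked-image latents as extra input channels (see steps 7/8 and the denoising loop)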
+
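+ # when `denoising_start` is set, the incoming image is assumed to be a partly denoised latent, so no
+ # fresh noise is added to it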
+ add_noise = denoising_start is None
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Prepare added time ids & embeddings & adapter features
+ if isinstance(adapter, MultiAdapter):
+ adapter_state = adapter(adapter_input, adapter_conditioning_scale)
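+ # the per-adapter scales were already passed to the MultiAdapter call above, so this loop leaves the
+ # returned feature maps unchanged (kept for symmetry with the single-adapter branch below)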
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v
+ else:
+ adapter_state = adapter(adapter_input)
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v * adapter_conditioning_scale
+ if num_images_per_prompt > 1:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
+ if do_classifier_free_guidance:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = torch.cat([v] * 2, dim=0)
+
+ # 10.2 Prepare control images
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ raise ValueError(f"{controlnet.__class__} is not supported.")
+
+ # 8.2 Create tensor stating which controlnets to keep
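+ # for every timestep, store a per-ControlNet flag that is 1.0 while the normalized step index lies inside
+ # the [control_guidance_start, control_guidance_end] window and 0.0 otherwise; the flag later multiplies
+ # the conditioning scale, effectively switching the ControlNet off outside its window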
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ if isinstance(self.controlnet, MultiControlNetModel):
+ controlnet_keep.append(keeps)
+ else:
+ controlnet_keep.append(keeps[0])
+ # ----------------------------------------------------------------
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
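+ # for classifier-free guidance the unconditional (negative) embeddings are stacked in front of the
+ # conditional ones, matching the `noise_pred.chunk(2)` order used in the denoising loop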
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 11.1 Apply denoising_end
+ if (
+ denoising_end is not None
+ and denoising_start is not None
+ and denoising_value_valid(denoising_end)
+ and denoising_value_valid(denoising_start)
+ and denoising_start >= denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`:"
+ + f" {denoising_end} when using type float."
+ )
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
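+ # e.g. denoising_end=0.8 keeps only the timesteps above 20% of the training schedule, leaving the
+ # remaining noise to be removed by a follow-up (refiner) pipeline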
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ added_cond_kwargs = {
+ "text_embeds": add_text_embeds,
+ "time_ids": add_time_ids,
+ }
+
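+ # the T2I-Adapter residuals are only injected for the first `cond_tau` fraction of the inference steps;
+ # after that the UNet runs without the adapter features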
+ if i < int(num_inference_steps * cond_tau):
+ down_block_additional_residuals = [state.clone() for state in adapter_state]
+ else:
+ down_block_additional_residuals = None
+
+ # ----------- ControlNet
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input_controlnet = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # the ControlNet input uses only the plain latents (no mask or masked-image channels); scale it for the current timestep
+ latent_model_input_controlnet = self.scheduler.scale_model_input(latent_model_input_controlnet, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input_controlnet
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
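+ # effective ControlNet scale for this step: the base conditioning scale multiplied by the 0/1 keep flag
+ # computed in `controlnet_keep`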
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ down_intrablock_additional_residuals=down_block_additional_residuals, # t2iadapter
+ down_block_additional_residuals=down_block_res_samples, # controlnet
+ mid_block_additional_residual=mid_block_res_sample, # controlnet
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred,
+ noise_pred_text,
+ guidance_rescale=guidance_rescale,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper,
+ noise,
+ torch.tensor([noise_timestep]),
+ )
+
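+ # keep the unmasked region pinned to the (re-noised) original image latents so that only the masked
+ # region is denoised freely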
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ if output_type != "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py b/diffusers/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..584820e86254a65ee0da26669c0eafacdb98729b
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
@@ -0,0 +1,1466 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torchvision
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLImg2ImgPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+ >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
+
+ >>> init_image = load_image(url).convert("RGB")
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt, image=init_image).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionXLDifferentialImg2ImgPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for differential image-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "add_neg_time_ids",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are always only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always only interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+ if num_inference_steps is None:
+ raise ValueError("`num_inference_steps` cannot be None.")
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+ raise ValueError(
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+ f" {type(num_inference_steps)}."
+ )
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ timesteps = timesteps[-num_inference_steps:]
+ return timesteps, num_inference_steps
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
+ ):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.text_encoder_2.to("cpu")
+ torch.cuda.empty_cache()
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
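+ # a 4-channel input is assumed to already be a latent; anything else is encoded with the VAE below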
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ if add_noise:
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
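+ # SDXL micro-conditioning: the time ids are (original_size, crop top-left, target_size), or (original_size, crop top-left, aesthetic_score) when the refiner-style `requires_aesthetics_score` config is set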
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def denoising_start(self):
+ return self._denoising_start
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: Union[
+ torch.Tensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.Tensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ strength: float = 0.3,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ map: torch.Tensor = None,
+ original_image: Union[
+ torch.Tensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.Tensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
+ The image(s) to modify with the pipeline.
+ strength (`float`, *optional*, defaults to 0.3):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of
+ `denoising_start` being declared as an integer, the value of `strength` will be ignored.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of IP-Adapters.
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
+ if `do_classifier_free_guidance` is set to `True`.
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess image
+ # image = self.image_processor.preprocess(image)  # ideally we would preprocess the image with diffusers, but for this POC we won't --- it throws a deprecation warning
+ map = torchvision.transforms.Resize(
+ tuple(s // self.vae_scale_factor for s in original_image.shape[2:]), antialias=None
+ )(map)
+
+ # 5. Prepare timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # begin diff diff change
+ total_time_steps = num_inference_steps
+ # end diff diff change
+
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ add_noise = True if denoising_start is None else False
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ add_noise,
+ )
+ # 7. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 8. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 9. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 9.1 Apply denoising_end
+ if (
+ denoising_end is not None
+ and denoising_start is not None
+ and denoising_value_valid(denoising_end)
+ and denoising_value_valid(denoising_start)
+ and denoising_start >= denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {denoising_end} when using type float."
+ )
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # preparations for diff diff
+ original_with_noise = self.prepare_latents(
+ original_image, timesteps, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+ )
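+        # thresholds[i] = i / total_time_steps; at step i, pixels whose change-map value exceeds
+        # thresholds[i] (plus any `denoising_start` offset) are re-filled from the noised original
+        # in the loop below, so higher map values preserve the input for more of the schedule.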
+ thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps
+ thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)
+ masks = map > (thresholds + (denoising_start or 0))
+ # end diff diff preparations
+
+ # 9.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # diff diff
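+                # step 0 (when not resuming with `denoising_start`): start from the fully
+                # noised original; on later steps, re-inject the original noised to the
+                # current timestep wherever this step's mask is still active.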
+ if i == 0 and denoising_start is None:
+ latents = original_with_noise[:1]
+ else:
+ mask = masks[i].unsqueeze(0)
+ # cast mask to the same type as latents etc
+ mask = mask.to(latents.dtype)
+ mask = mask.unsqueeze(1) # fit shape
+ latents = original_with_noise[i] * mask + latents * (1 - mask)
+ # end diff diff
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+                        # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+ else:
+ raise ValueError(
+ "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
+ )
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ elif latents.dtype != self.vae.dtype:
+ if torch.backends.mps.is_available():
+                    # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ self.vae = self.vae.to(latents.dtype)
+ else:
+ raise ValueError(
+ "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
+ )
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_xl_instandid_img2img.py b/diffusers/examples/community/pipeline_stable_diffusion_xl_instandid_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aeba79ae930ad6ba7d2b1550cf04c6852ec7f33
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_xl_instandid_img2img.py
@@ -0,0 +1,1077 @@
+# Copyright 2024 The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+
+from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
+from diffusers.image_processor import PipelineImageInput
+from diffusers.models import ControlNetModel
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
+
+
+try:
+ import xformers
+ import xformers.ops
+
+ xformers_available = True
+except Exception:
+ xformers_available = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+logger.warning(
+ "To use instant id pipelines, please make sure you have the `insightface` library installed: `pip install insightface`."
+    " Please refer to: https://huggingface.co/InstantX/InstantID for further instructions regarding inference"
+)
+
+
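+# Position-wise feed-forward block (LayerNorm -> Linear -> GELU -> Linear) used inside the Resampler.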
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+            latents (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+class Resampler(nn.Module):
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ ):
+ super().__init__()
+
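+        # Learned query tokens, shared across the batch; cross-attending against the input
+        # features distills them into a fixed number (`num_queries`) of output tokens.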
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x):
+ latents = self.latents.repeat(x.size(0), 1, 1)
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
+
+class AttnProcessor(nn.Module):
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ ):
+ super().__init__()
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class IPAttnProcessor(nn.Module):
+ r"""
+    Attention processor for IP-Adapter.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            The weight scale of the image prompt.
+        num_tokens (`int`, defaults to 4; should be 16 when using ip_adapter_plus):
+            The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ if xformers_available:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ else:
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ if xformers_available:
+ ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
+ else:
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
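+        # blend the text-conditioned and image-conditioned attention outputs, weighted by `scale`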
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ # TODO attention_mask
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ return hidden_states
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate insightface
+ >>> import diffusers
+ >>> from diffusers.utils import load_image
+ >>> from diffusers.models import ControlNetModel
+
+ >>> import cv2
+ >>> import torch
+ >>> import numpy as np
+ >>> from PIL import Image
+
+ >>> from insightface.app import FaceAnalysis
+ >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
+
+ >>> # download 'antelopev2' under ./models
+ >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ >>> app.prepare(ctx_id=0, det_size=(640, 640))
+
+ >>> # download models under ./checkpoints
+ >>> face_adapter = f'./checkpoints/ip-adapter.bin'
+ >>> controlnet_path = f'./checkpoints/ControlNetModel'
+
+ >>> # load IdentityNet
+ >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+
+ >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+ >>> pipe.cuda()
+
+ >>> # load adapter
+ >>> pipe.load_ip_adapter_instantid(face_adapter)
+
+ >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
+ >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
+
+ >>> # load an image
+    >>> face_image = load_image("your-example.jpg")
+
+ >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
+ >>> face_emb = face_info['embedding']
+ >>> face_kps = draw_kps(face_image, face_info['kps'])
+
+ >>> pipe.set_ip_adapter_scale(0.8)
+
+ >>> # generate image
+ >>> image = pipe(
+ ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
+ ... ).images[0]
+ ```
+"""
+
+
+def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
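+    # Rasterize the five facial keypoints into an RGB conditioning image for IdentityNet:
+    # each pair in limbSeq is drawn as a dimmed ellipse "limb", then every keypoint as a dot.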
+ stickwidth = 4
+ limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+ kps = np.array(kps)
+
+ w, h = image_pil.size
+ out_img = np.zeros([h, w, 3])
+
+ for i in range(len(limbSeq)):
+ index = limbSeq[i]
+ color = color_list[index[0]]
+
+ x = kps[index][:, 0]
+ y = kps[index][:, 1]
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+ polygon = cv2.ellipse2Poly(
+ (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
+ )
+ out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+ out_img = (out_img * 0.6).astype(np.uint8)
+
+ for idx_kp, kp in enumerate(kps):
+ color = color_list[idx_kp]
+ x, y = kp
+ out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
+
+ out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
+ return out_img_pil
+
+
+class StableDiffusionXLInstantIDImg2ImgPipeline(StableDiffusionXLControlNetImg2ImgPipeline):
+ def cuda(self, dtype=torch.float16, use_xformers=False):
+ self.to("cuda", dtype)
+
+ if hasattr(self, "image_proj_model"):
+ self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
+
+ if use_xformers:
+ if is_xformers_available():
+ import xformers
+ from packaging import version
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ self.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
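+        # The InstantID checkpoint bundles two parts: the image projection (Resampler) weights
+        # under "image_proj" and the per-layer IP-adapter attention weights under "ip_adapter".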
+ self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
+ self.set_ip_adapter(model_ckpt, num_tokens, scale)
+
+ def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
+ image_proj_model = Resampler(
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=num_tokens,
+ embedding_dim=image_emb_dim,
+ output_dim=self.unet.config.cross_attention_dim,
+ ff_mult=4,
+ )
+
+ image_proj_model.eval()
+
+ self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
+ state_dict = torch.load(model_ckpt, map_location="cpu")
+ if "image_proj" in state_dict:
+ state_dict = state_dict["image_proj"]
+ self.image_proj_model.load_state_dict(state_dict)
+
+ self.image_proj_model_in_features = image_emb_dim
+
+ def set_ip_adapter(self, model_ckpt, num_tokens, scale):
+ unet = self.unet
+ attn_procs = {}
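+        # Self-attention layers (`*.attn1.processor`) keep the plain AttnProcessor; every
+        # cross-attention layer is replaced with an IPAttnProcessor sized for its block.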
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
+ else:
+ attn_procs[name] = IPAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=scale,
+ num_tokens=num_tokens,
+ ).to(unet.device, dtype=unet.dtype)
+ unet.set_attn_processor(attn_procs)
+
+ state_dict = torch.load(model_ckpt, map_location="cpu")
+ ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
+ if "ip_adapter" in state_dict:
+ state_dict = state_dict["ip_adapter"]
+ ip_layers.load_state_dict(state_dict)
+
+ def set_ip_adapter_scale(self, scale):
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ for attn_processor in unet.attn_processors.values():
+ if isinstance(attn_processor, IPAttnProcessor):
+ attn_processor.scale = scale
+
+ def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
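+        # Reshape the raw face embedding to (1, n_tokens, in_features), prepend an all-zero
+        # embedding for the unconditional branch when CFG is enabled, then project it through
+        # the Resampler into tokens matching the UNet cross-attention dimension.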
+ if isinstance(prompt_image_emb, torch.Tensor):
+ prompt_image_emb = prompt_image_emb.clone().detach()
+ else:
+ prompt_image_emb = torch.tensor(prompt_image_emb)
+
+ prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
+ prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
+
+ if do_classifier_free_guidance:
+ prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
+ else:
+ prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
+
+ prompt_image_emb = self.image_proj_model(prompt_image_emb)
+ return prompt_image_emb
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ strength: float = 0.8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders.
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
+                `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+                accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height
+                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+                `init`, images must be passed as a list such that each element of the list can be correctly batched for
+                input to a single ControlNet.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, pooled text embeddings are generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
+ argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should usually
+                be the same as `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference, with the following
+                arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] is returned,
+ otherwise a `tuple` is returned containing the output images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ control_image,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ None,
+ None,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3.1 Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ prompt_2,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 3.2 Encode image prompt
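+        # `image_embeds` is the insightface face embedding; after projection it is both appended
+        # to the text tokens for the UNet and passed directly to the IdentityNet ControlNet.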
+ prompt_image_emb = self._encode_prompt_image_emb(
+ image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance
+ )
+ bs_embed, seq_len, _ = prompt_image_emb.shape
+ prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
+ prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # 4. Prepare image and controlnet_conditioning_image
+ image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = control_image.shape[-2:]
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ height, width = control_image[0].shape[-2:]
+ else:
+ assert False
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ True,
+ )
+
+        # 6.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Create tensor stating which controlnets to keep
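+        # controlnet_keep[i] is 1.0 while the step fraction lies inside
+        # [control_guidance_start, control_guidance_end] and 0.0 outside that window.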
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ # 7.2 Prepare added time ids & embeddings
+ if isinstance(control_image, list):
+ original_size = original_size or control_image[0].shape[-2:]
+ else:
+ original_size = original_size or control_image.shape[-2:]
+ target_size = target_size or (height, width)
+
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+ add_text_embeds = pooled_prompt_embeds
+
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device)  # already repeated for the batch above
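+        # InstantID: append the projected face tokens after the text tokens; IPAttnProcessor
+        # later splits off the last `num_tokens` entries as the image prompt.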
+ encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
+
+ # 8. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ is_unet_compiled = is_compiled_module(self.unet)
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Relevant thread:
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ torch._inductor.cudagraph_mark_step_begin()
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=prompt_image_emb,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_xl_instantid.py b/diffusers/examples/community/pipeline_stable_diffusion_xl_instantid.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eead8861e4d760dfe136f495570ed2c0781ecf5
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_xl_instantid.py
@@ -0,0 +1,1066 @@
+# Copyright 2024 The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+
+from diffusers import StableDiffusionXLControlNetPipeline
+from diffusers.image_processor import PipelineImageInput
+from diffusers.models import ControlNetModel
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
+
+
+try:
+ import xformers
+ import xformers.ops
+
+ xformers_available = True
+except Exception:
+ xformers_available = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+logger.warning(
+ "To use instant id pipelines, please make sure you have the `insightface` library installed: `pip install insightface`."
+    " Please refer to: https://huggingface.co/InstantX/InstantID for further instructions regarding inference"
+)
+
+
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+            latents (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+class Resampler(nn.Module):
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ ):
+ super().__init__()
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x):
+ latents = self.latents.repeat(x.size(0), 1, 1)
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
+
+class AttnProcessor(nn.Module):
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ ):
+ super().__init__()
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class IPAttnProcessor(nn.Module):
+ r"""
+ Attention processor for IP-Adapter.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ The weight scale of the image prompt.
+ num_tokens (`int`, defaults to 4; should be 16 when using ip_adapter_plus):
+ The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ if xformers_available:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ else:
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ if xformers_available:
+ ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
+ else:
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ # TODO attention_mask
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ return hidden_states
+
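+# Note on IPAttnProcessor above: it implements decoupled cross-attention. The last `num_tokens`
+# entries of `encoder_hidden_states` are treated as image-prompt tokens and attended to through the
+# separate `to_k_ip`/`to_v_ip` projections; the result is blended with the text attention output as
+# hidden_states = attn(text) + scale * attn(image), so `scale` controls how strongly the image
+# prompt steers generation.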
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate insightface
+ >>> import diffusers
+ >>> from diffusers.utils import load_image
+ >>> from diffusers.models import ControlNetModel
+
+ >>> import cv2
+ >>> import torch
+ >>> import numpy as np
+ >>> from PIL import Image
+
+ >>> from insightface.app import FaceAnalysis
+ >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
+
+ >>> # download 'antelopev2' under ./models
+ >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ >>> app.prepare(ctx_id=0, det_size=(640, 640))
+
+ >>> # download models under ./checkpoints
+ >>> face_adapter = f'./checkpoints/ip-adapter.bin'
+ >>> controlnet_path = f'./checkpoints/ControlNetModel'
+
+ >>> # load IdentityNet
+ >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+
+ >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+ >>> pipe.cuda()
+
+ >>> # load adapter
+ >>> pipe.load_ip_adapter_instantid(face_adapter)
+
+ >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
+ >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"
+
+ >>> # load an image
+ >>> face_image = load_image("your-example.jpg")
+
+ >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
+ >>> face_emb = face_info['embedding']
+ >>> face_kps = draw_kps(face_image, face_info['kps'])
+
+ >>> pipe.set_ip_adapter_scale(0.8)
+
+ >>> # generate image
+ >>> image = pipe(
+ ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
+ ... ).images[0]
+ ```
+"""
+
+
+def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
+ stickwidth = 4
+ limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+ kps = np.array(kps)
+
+ w, h = image_pil.size
+ out_img = np.zeros([h, w, 3])
+
+ for i in range(len(limbSeq)):
+ index = limbSeq[i]
+ color = color_list[index[0]]
+
+ x = kps[index][:, 0]
+ y = kps[index][:, 1]
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+ polygon = cv2.ellipse2Poly(
+ (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
+ )
+ out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+ out_img = (out_img * 0.6).astype(np.uint8)
+
+ for idx_kp, kp in enumerate(kps):
+ color = color_list[idx_kp]
+ x, y = kp
+ out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
+
+ out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
+ return out_img_pil
+
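+# Note on draw_kps above: it rasterizes the five facial keypoints as colored ellipses and dots on a
+# black canvas matching the input image size; the resulting PIL image is the spatial condition that
+# is passed to the ControlNet (IdentityNet) through the pipeline's `image` argument.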
+
+class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
+ def cuda(self, dtype=torch.float16, use_xformers=False):
+ self.to("cuda", dtype)
+
+ if hasattr(self, "image_proj_model"):
+ self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
+
+ if use_xformers:
+ if is_xformers_available():
+ import xformers
+ from packaging import version
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ self.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
+ self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
+ self.set_ip_adapter(model_ckpt, num_tokens, scale)
+
+ def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
+ image_proj_model = Resampler(
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=num_tokens,
+ embedding_dim=image_emb_dim,
+ output_dim=self.unet.config.cross_attention_dim,
+ ff_mult=4,
+ )
+
+ image_proj_model.eval()
+
+ self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
+ state_dict = torch.load(model_ckpt, map_location="cpu")
+ if "image_proj" in state_dict:
+ state_dict = state_dict["image_proj"]
+ self.image_proj_model.load_state_dict(state_dict)
+
+ self.image_proj_model_in_features = image_emb_dim
+
+ def set_ip_adapter(self, model_ckpt, num_tokens, scale):
+ unet = self.unet
+ attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
+ else:
+ attn_procs[name] = IPAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=scale,
+ num_tokens=num_tokens,
+ ).to(unet.device, dtype=unet.dtype)
+ unet.set_attn_processor(attn_procs)
+
+ state_dict = torch.load(model_ckpt, map_location="cpu")
+ ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
+ if "ip_adapter" in state_dict:
+ state_dict = state_dict["ip_adapter"]
+ ip_layers.load_state_dict(state_dict)
+
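+ # Note on set_ip_adapter above: self-attention layers (attn1, where cross_attention_dim is None)
+ # keep the plain AttnProcessor; only cross-attention layers receive an IPAttnProcessor, so the
+ # face embedding is injected exactly where text conditioning is consumed.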
+ def set_ip_adapter_scale(self, scale):
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ for attn_processor in unet.attn_processors.values():
+ if isinstance(attn_processor, IPAttnProcessor):
+ attn_processor.scale = scale
+
+ def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
+ if isinstance(prompt_image_emb, torch.Tensor):
+ prompt_image_emb = prompt_image_emb.clone().detach()
+ else:
+ prompt_image_emb = torch.tensor(prompt_image_emb)
+
+ prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
+ prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
+
+ if do_classifier_free_guidance:
+ prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
+ else:
+ prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
+
+ prompt_image_emb = self.image_proj_model(prompt_image_emb)
+ return prompt_image_emb
+
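+ # Note on _encode_prompt_image_emb above: under classifier-free guidance the image embedding is
+ # paired with an all-zero embedding for the unconditional branch before being passed through the
+ # Resampler, mirroring how empty text prompts are handled.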
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+ accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
+ input to a single ControlNet.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, pooled text embeddings are generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
+ argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. In most cases it
+ should be the same as `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. It is called with
+ the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned containing the output images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3.1 Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ prompt_2,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 3.2 Encode image prompt
+ prompt_image_emb = self._encode_prompt_image_emb(
+ image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance
+ )
+ bs_embed, seq_len, _ = prompt_image_emb.shape
+ prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
+ prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = image.shape[-2:]
+ elif isinstance(controlnet, MultiControlNetModel):
+ images = []
+
+ for image_ in image:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+
+ image = images
+ height, width = image[0].shape[-2:]
+ else:
+ assert False
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Create tensor stating which controlnets to keep
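+ # controlnet_keep[i] is 1.0 for every ControlNet whose [control_guidance_start, control_guidance_end]
+ # window (expressed as a fraction of the total steps) contains step i, and 0.0 otherwise.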
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ # 7.2 Prepare added time ids & embeddings
+ if isinstance(image, list):
+ original_size = original_size or image[0].shape[-2:]
+ else:
+ original_size = original_size or image.shape[-2:]
+ target_size = target_size or (height, width)
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
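+ # The resampled face-identity tokens are appended after the text tokens; IPAttnProcessor later
+ # splits them off again using `num_tokens`.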
+ encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ is_unet_compiled = is_compiled_module(self.unet)
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Relevant thread:
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ torch._inductor.cudagraph_mark_step_begin()
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=prompt_image_emb,
+ controlnet_cond=image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_xl_ipex.py b/diffusers/examples/community/pipeline_stable_diffusion_xl_ipex.py
new file mode 100644
index 0000000000000000000000000000000000000000..022dfb1abf8247b7fcceca9ab2c301702b16ccef
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_xl_ipex.py
@@ -0,0 +1,1430 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import intel_extension_for_pytorch as ipex
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers import StableDiffusionXLPipeline
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipelineIpex
+
+ >>> # SDXL-Turbo, a distilled version of SDXL 1.0, trained for real-time synthesis
+ >>> pipe = StableDiffusionXLPipelineIpex.from_pretrained(
+ ... "stabilityai/sdxl-turbo", low_cpu_mem_usage=True, use_safetensors=True
+ ... )
+
+ >>> num_inference_steps = 1
+ >>> guidance_scale = 0.0
+ >>> use_bf16 = True
+ >>> data_type = torch.bfloat16 if use_bf16 else torch.float32
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+
+ >>> # value of image height/width should be consistent with the pipeline inference
+ >>> # For Float32
+ >>> pipe.prepare_for_ipex(torch.float32, prompt, height=512, width=512)
+ >>> # For BFloat16
+ >>> pipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)
+
+ >>> # value of image height/width should be consistent with 'prepare_for_ipex()'
+ >>> # For Float32
+ >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]
+ >>> # For BFloat16
+ >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ ...     image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
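+# Note on rescale_noise_cfg above: with rescale factor g it computes
+# noise_cfg <- g * noise_cfg * (std_text / std_cfg) + (1 - g) * noise_cfg,
+# i.e. it interpolates between the raw CFG prediction and a variant whose per-sample standard
+# deviation matches the text-conditional prediction, which mitigates overexposure.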
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionXLPipelineIpex(
+ StableDiffusionXLPipeline,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL on IPEX.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition, the pipeline inherits the following loading methods:
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1.0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "negative_add_time_ids",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ # super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.default_sample_size = self.unet.config.sample_size
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+ return image_embeds, uncond_image_embeds
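+ # Usage sketch (commented out, not part of the pipeline; variable names are illustrative):
+ # for a single PIL image and `num_images_per_prompt=2`, the embeddings are repeated once per
+ # generated image and paired with an all-zero unconditional counterpart of the same shape.
+ # image_embeds, uncond_image_embeds = self.encode_image(pil_image, device, num_images_per_prompt=2)
+ # assert image_embeds.shape[0] == 2 and torch.all(uncond_image_embeds == 0)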
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, and will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
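+ # Illustration (commented out; scheduler classes are named only as examples): with
+ # `DDIMScheduler` both `eta` and `generator` are forwarded, while a scheduler whose `step()`
+ # has no `eta` parameter (e.g. `EulerDiscreteScheduler`) yields only {"generator": generator}.
+ # extra_step_kwargs = self.prepare_extra_step_kwargs(generator=torch.Generator().manual_seed(0), eta=0.0)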
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
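+ # Dimension sketch (commented out; the concrete numbers assume the stock SDXL base config,
+ # i.e. addition_time_embed_dim=256 and a text_encoder_2 projection_dim of 1280):
+ # 6 ids (original_size + crops_coords_top_left + target_size) * 256 + 1280 = 2816,
+ # which must equal self.unet.add_embedding.linear_1.in_features.
+ # add_time_ids = self._get_add_time_ids((1024, 1024), (0, 0), (1024, 1024),
+ #                                        dtype=torch.float16, text_encoder_projection_dim=1280)
+ # add_time_ids.shape == (1, 6)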
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values at which to generate embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
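+ # Usage sketch (commented out): as in step 9 of `__call__` below, `w` is usually
+ # `guidance_scale - 1` repeated per sample, and `embedding_dim` is
+ # `self.unet.config.time_cond_proj_dim` (the values here are illustrative only).
+ # w = torch.tensor([7.5 - 1.0, 7.5 - 1.0])
+ # emb = self.get_guidance_scale_embedding(w, embedding_dim=256)  # -> shape (2, 256)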
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should usually be
+ the same as the `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ callback_on_step_end (`Callable`, *optional*):
+ A function called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ if ip_adapter_image is not None:
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+ image_embeds = image_embeds.to(device)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
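+ # Worked example (values illustrative): with num_train_timesteps=1000 and denoising_end=0.8,
+ # the cutoff is round(1000 - 0.8 * 1000) = 200, so only timesteps >= 200 are kept and roughly
+ # the first 80% of the schedule is run before handing off to, e.g., a refiner pipeline.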
+
+ # 9. Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ # noise_pred = self.unet(
+ # latent_model_input,
+ # t,
+ # encoder_hidden_states=prompt_embeds,
+ # timestep_cond=timestep_cond,
+ # cross_attention_kwargs=self.cross_attention_kwargs,
+ # added_cond_kwargs=added_cond_kwargs,
+ # return_dict=False,
+ # )[0]
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs=added_cond_kwargs,
+ )["sample"]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
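+ # rescale_noise_cfg (Sec. 3.4 of the paper) roughly rescales the guided prediction toward the
+ # std of the text branch; a sketch, assuming 4D latents (see diffusers' helper for the exact code):
+ #   std_text = noise_pred_text.std(dim=(1, 2, 3), keepdim=True)
+ #   std_cfg = noise_pred.std(dim=(1, 2, 3), keepdim=True)
+ #   noise_pred = guidance_rescale * (noise_pred * std_text / std_cfg) + (1 - guidance_rescale) * noise_pred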
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
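+ # End-to-end usage sketch (commented out; `PipelineClass` is a placeholder for however this
+ # pipeline is registered, and the checkpoint id is the SDXL base model referenced above):
+ # pipe = PipelineClass.from_pretrained(
+ #     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ # ).to("cuda")
+ # image = pipe(prompt="a photo of an astronaut riding a horse on mars", guidance_scale=5.0).images[0]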
+
+ @torch.no_grad()
+ def prepare_for_ipex(
+ self,
+ dtype=torch.float32,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = "cpu"
+ do_classifier_free_guidance = self.do_classifier_free_guidance
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ if ip_adapter_image is not None:
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+ image_embeds = image_embeds.to(device)
+
+ dummy = torch.ones(1, dtype=torch.int32)
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ self.unet = self.unet.to(memory_format=torch.channels_last)
+ self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last)
+ self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last)
+
+ unet_input_example = {
+ "sample": latent_model_input,
+ "timestep": dummy,
+ "encoder_hidden_states": prompt_embeds,
+ "added_cond_kwargs": added_cond_kwargs,
+ }
+
+ vae_decoder_input_example = latents
+
+ # optimize with ipex
+ if dtype == torch.bfloat16:
+ self.unet = ipex.optimize(
+ self.unet.eval(),
+ dtype=torch.bfloat16,
+ inplace=True,
+ )
+ self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)
+ self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
+ elif dtype == torch.float32:
+ self.unet = ipex.optimize(
+ self.unet.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ level="O1",
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ self.vae.decoder = ipex.optimize(
+ self.vae.decoder.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ level="O1",
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ self.text_encoder = ipex.optimize(
+ self.text_encoder.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ level="O1",
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ else:
+ raise ValueError(" The value of 'dtype' should be 'torch.bfloat16' or 'torch.float32' !")
+
+ # trace unet model to get better performance on IPEX
+ with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+ unet_trace_model = torch.jit.trace(
+ self.unet, example_kwarg_inputs=unet_input_example, check_trace=False, strict=False
+ )
+ unet_trace_model = torch.jit.freeze(unet_trace_model)
+ self.unet.forward = unet_trace_model.forward
+
+ # trace vae.decoder model to get better performance on IPEX
+ with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+ vae_decoder_trace_model = torch.jit.trace(
+ self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False
+ )
+ vae_decoder_trace_model = torch.jit.freeze(vae_decoder_trace_model)
+ self.vae.decoder.forward = vae_decoder_trace_model.forward
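+ # Usage sketch for the IPEX path (commented out; assumes intel_extension_for_pytorch is
+ # installed and the pipeline runs on CPU; prompt and resolution should match the later call):
+ # pipe.prepare_for_ipex(dtype=torch.bfloat16, prompt="a prompt", height=1024, width=1024)
+ # with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ #     image = pipe(prompt="a prompt", height=1024, width=1024).images[0]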
diff --git a/diffusers/examples/community/pipeline_zero1to3.py b/diffusers/examples/community/pipeline_zero1to3.py
new file mode 100644
index 0000000000000000000000000000000000000000..95bb37ce02b701611f2d1c265fd0fece302846af
--- /dev/null
+++ b/diffusers/examples/community/pipeline_zero1to3.py
@@ -0,0 +1,793 @@
+ # A diffusers implementation of Zero1to3 (https://github.com/cvlab-columbia/zero123), ICCV 2023
+# by Xin Kong
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import kornia
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+# from ...configuration_utils import FrozenDict
+# from ...models import AutoencoderKL, UNet2DConditionModel
+# from ...schedulers import KarrasDiffusionSchedulers
+# from ...utils import (
+# deprecate,
+# is_accelerate_available,
+# is_accelerate_version,
+# logging,
+# randn_tensor,
+# replace_example_docstring,
+# )
+# from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+# from . import StableDiffusionPipelineOutput
+# from .safety_checker import StableDiffusionSafetyChecker
+from diffusers import AutoencoderKL, DiffusionPipeline, StableDiffusionMixin, UNet2DConditionModel
+from diffusers.configuration_utils import ConfigMixin, FrozenDict
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+# todo
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+class CCProjection(ModelMixin, ConfigMixin):
+ def __init__(self, in_channel=772, out_channel=768):
+ super().__init__()
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.projection = torch.nn.Linear(in_channel, out_channel)
+
+ def forward(self, x):
+ return self.projection(x)
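+ # Shape sketch (commented out): the default in_channel=772 is the 768-dim CLIP image embedding
+ # concatenated with the 4-dim pose embedding built in `_encode_pose` below.
+ # proj = CCProjection()                  # Linear(772 -> 768)
+ # out = proj(torch.randn(2, 1, 772))     # -> torch.Size([2, 1, 768])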
+
+
+class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for single view conditioned novel view generation using Zero1to3.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ cc_projection ([`CCProjection`]):
+ Projection layer to project the concatenated CLIP features and pose embeddings to the original CLIP feature size.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ cc_projection: CCProjection,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ cc_projection=cc_projection,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+ # self.model_mode = None
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def CLIP_preprocess(self, x):
+ dtype = x.dtype
+ # following OpenAI's implementation
+ # TODO HF OpenAI CLIP preprocessing issue https://github.com/huggingface/transformers/issues/22505#issuecomment-1650170741
+ # Follow OpenAI's preprocessing exactly: the input tensor must be in [-1, 1], otherwise the result differs, see https://github.com/huggingface/transformers/pull/22608
+ if isinstance(x, torch.Tensor):
+ if x.min() < -1.0 or x.max() > 1.0:
+ raise ValueError("Expected input tensor to have values in the range [-1, 1]")
+ x = kornia.geometry.resize(
+ x.to(torch.float32), (224, 224), interpolation="bicubic", align_corners=True, antialias=False
+ ).to(dtype=dtype)
+ x = (x + 1.0) / 2.0
+ # renormalize according to clip
+ x = kornia.enhance.normalize(
+ x, torch.Tensor([0.48145466, 0.4578275, 0.40821073]), torch.Tensor([0.26862954, 0.26130258, 0.27577711])
+ )
+ return x
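+ # Sketch (commented out): inputs must already be scaled to [-1, 1]; the method resizes to
+ # 224x224 with bicubic interpolation and applies CLIP's channel statistics.
+ # x = torch.rand(1, 3, 256, 256) * 2.0 - 1.0
+ # x = self.CLIP_preprocess(x)  # -> shape (1, 3, 224, 224), CLIP-normalized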
+
+ # from image_variation
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ assert image.ndim == 4, "Image must have 4 dimensions"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ image = image.to(device=device, dtype=dtype)
+
+ image = self.CLIP_preprocess(image)
+ # if not isinstance(image, torch.Tensor):
+ # # 0-255
+ # print("Warning: image is processed by hf's preprocess, which is different from openai original's.")
+ # image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+ image_embeddings = self.image_encoder(image).image_embeds.to(dtype=dtype)
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+
+ return image_embeddings
+
+ def _encode_pose(self, pose, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.cc_projection.parameters()).dtype
+ if isinstance(pose, torch.Tensor):
+ pose_embeddings = pose.unsqueeze(1).to(device=device, dtype=dtype)
+ else:
+ if isinstance(pose[0], list):
+ pose = torch.Tensor(pose)
+ else:
+ pose = torch.Tensor([pose])
+ x, y, z = pose[:, 0].unsqueeze(1), pose[:, 1].unsqueeze(1), pose[:, 2].unsqueeze(1)
+ pose_embeddings = (
+ torch.cat([torch.deg2rad(x), torch.sin(torch.deg2rad(y)), torch.cos(torch.deg2rad(y)), z], dim=-1)
+ .unsqueeze(1)
+ .to(device=device, dtype=dtype)
+ ) # B, 1, 4
+ # duplicate pose embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = pose_embeddings.shape
+ pose_embeddings = pose_embeddings.repeat(1, num_images_per_prompt, 1)
+ pose_embeddings = pose_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(pose_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ pose_embeddings = torch.cat([negative_prompt_embeds, pose_embeddings])
+ return pose_embeddings
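+ # Example for _encode_pose above (illustrative): a pose [elevation_deg, azimuth_deg, radius]
+ # such as [30.0, 90.0, 1.5] is encoded as
+ #   [deg2rad(30.0), sin(deg2rad(90.0)), cos(deg2rad(90.0)), 1.5] ~= [0.5236, 1.0, 0.0, 1.5],
+ # i.e. a (B, 1, 4) tensor that is later concatenated with the CLIP image embedding and
+ # projected by `cc_projection`.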
+
+ def _encode_image_with_pose(self, image, pose, device, num_images_per_prompt, do_classifier_free_guidance):
+ img_prompt_embeds = self._encode_image(image, device, num_images_per_prompt, False)
+ pose_prompt_embeds = self._encode_pose(pose, device, num_images_per_prompt, False)
+ prompt_embeds = torch.cat([img_prompt_embeds, pose_prompt_embeds], dim=-1)
+ prompt_embeds = self.cc_projection(prompt_embeds)
+ # prompt_embeds = img_prompt_embeds
+ # follow 0123, add negative prompt, after projection
+ if do_classifier_free_guidance:
+ negative_prompt = torch.zeros_like(prompt_embeds)
+ prompt_embeds = torch.cat([negative_prompt, prompt_embeds])
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, image, height, width, callback_steps):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
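+ # Shape sketch for prepare_latents above (illustrative, assuming the usual SD VAE where
+ # `vae_scale_factor == 8`): a 256x256 output with 4 latent channels yields noise of shape
+ # (batch_size, 4, 32, 32), already scaled by `self.scheduler.init_noise_sigma`.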
+
+ def prepare_img_latents(self, image, batch_size, dtype, device, generator=None, do_classifier_free_guidance=False):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ assert image.ndim == 4, "Image must have 4 dimensions"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ image = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.mode()
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.mode()
+
+ # init_latents = self.vae.config.scaling_factor * init_latents # todo in original zero123's inference gradio_new.py, model.encode_first_stage() is not scaled by scaling_factor
+ if batch_size > init_latents.shape[0]:
+ # init_latents = init_latents.repeat(batch_size // init_latents.shape[0], 1, 1, 1)
+ num_images_per_prompt = batch_size // init_latents.shape[0]
+ # duplicate image latents for each generation per prompt, using mps friendly method
+ bs_embed, emb_c, emb_h, emb_w = init_latents.shape
+ init_latents = init_latents.unsqueeze(1)
+ init_latents = init_latents.repeat(1, num_images_per_prompt, 1, 1, 1)
+ init_latents = init_latents.view(bs_embed * num_images_per_prompt, emb_c, emb_h, emb_w)
+
+ # init_latents = torch.cat([init_latents]*2) if do_classifier_free_guidance else init_latents # follow zero123
+ init_latents = (
+ torch.cat([torch.zeros_like(init_latents), init_latents]) if do_classifier_free_guidance else init_latents
+ )
+
+ init_latents = init_latents.to(device=device, dtype=dtype)
+ return init_latents
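+ # Note on prepare_img_latents above: with classifier-free guidance enabled, the returned image
+ # latents have shape (2 * batch_size, C, h, w) where the first half is all zeros (the unconditional
+ # branch), mirroring how `latent_model_input` is duplicated before the channel-wise concat in the
+ # denoising loop.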
+
+ # def load_cc_projection(self, pretrained_weights=None):
+ # self.cc_projection = torch.nn.Linear(772, 768)
+ # torch.nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768])
+ # torch.nn.init.zeros_(list(self.cc_projection.parameters())[1])
+ # if pretrained_weights is not None:
+ # self.cc_projection.load_state_dict(pretrained_weights)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ input_imgs: Union[torch.Tensor, PIL.Image.Image] = None,
+ prompt_imgs: Union[torch.Tensor, PIL.Image.Image] = None,
+ poses: Union[List[float], List[List[float]]] = None,
+ torch_dtype=torch.float32,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 3.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: float = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ input_imgs (`PIL` or `List[PIL]`, *optional*):
+ The single input image for each 3D object.
+ prompt_imgs (`PIL` or `List[PIL]`, *optional*):
+ Same as `input_imgs`, but used as the image prompt condition and encoded by the CLIP image encoder.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 3.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ prompt, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ # input_image = hint_imgs
+ self.check_inputs(input_imgs, height, width, callback_steps)
+
+ # 2. Define call parameters
+ if isinstance(input_imgs, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(input_imgs, list):
+ batch_size = len(input_imgs)
+ else:
+ batch_size = input_imgs.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input image with pose as prompt
+ prompt_embeds = self._encode_image_with_pose(
+ prompt_imgs, poses, device, num_images_per_prompt, do_classifier_free_guidance
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ 4,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare image latents
+ img_latents = self.prepare_img_latents(
+ input_imgs,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = torch.cat([latent_model_input, img_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # 9. Post-processing
+ has_nsfw_concept = None
+ if output_type == "latent":
+ image = latents
+ elif output_type == "pil":
+ # decode latents and convert to PIL
+ image = self.decode_latents(latents)
+ image = self.numpy_to_pil(image)
+ else:
+ # decode latents to a NumPy array
+ image = self.decode_latents(latents)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/regional_prompting_stable_diffusion.py b/diffusers/examples/community/regional_prompting_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a022987ba9d379982461056619401aaf44e0341
--- /dev/null
+++ b/diffusers/examples/community/regional_prompting_stable_diffusion.py
@@ -0,0 +1,618 @@
+import math
+from typing import Dict, Optional
+
+import torch
+import torchvision.transforms.functional as FF
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import StableDiffusionPipeline
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import USE_PEFT_BACKEND
+
+
+try:
+ from compel import Compel
+except ImportError:
+ Compel = None
+
+KCOMM = "ADDCOMM"
+KBRK = "BREAK"
+
+
+class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
+ r"""
+ Args for Regional Prompting Pipeline:
+ rp_args: dict
+ Required
+ rp_args["mode"]: cols, rows, prompt, prompt-ex
+ for cols, rows mode
+ rp_args["div"]: ex) 1;1;1 (divide into 3 regions)
+ for prompt, prompt-ex mode
+ rp_args["th"]: ex) 0.5,0.5,0.6 (threshold per region for prompt mode)
+
+ Optional
+ rp_args["save_mask"]: True/False (save masks in prompt mode)
+ rp_args["power"]: int (exponent applied to the attention maps in prompt mode, defaults to 1)
+
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
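+ # Usage sketch (illustrative only; the model id and prompts below are placeholders):
+ #   pipe = RegionalPromptingStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ #   images = pipe(
+ #       prompt="green hair twintail BREAK red blouse BREAK blue skirt",
+ #       rp_args={"mode": "rows", "div": "1;1;1"},
+ #   ).images
+ # This splits the canvas into three horizontal bands, one per BREAK-separated sub-prompt.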
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ scheduler,
+ safety_checker,
+ feature_extractor,
+ requires_safety_checker,
+ )
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: str,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: str = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ rp_args: Dict[str, str] = None,
+ ):
+ active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
+ if negative_prompt is None:
+ negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
+
+ device = self._execution_device
+ regions = 0
+
+ self.power = int(rp_args["power"]) if "power" in rp_args else 1
+
+ prompts = prompt if isinstance(prompt, list) else [prompt]
+ n_prompts = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ self.batch = batch = num_images_per_prompt * len(prompts)
+ all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
+ all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
+
+ equal = len(all_prompts_cn) == len(all_n_prompts_cn)
+
+ if Compel:
+ compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
+
+ def getcompelembs(prps):
+ embl = []
+ for prp in prps:
+ embl.append(compel.build_conditioning_tensor(prp))
+ return torch.cat(embl)
+
+ conds = getcompelembs(all_prompts_cn)
+ unconds = getcompelembs(all_n_prompts_cn)
+ embs = getcompelembs(prompts)
+ n_embs = getcompelembs(n_prompts)
+ prompt = negative_prompt = None
+ else:
+ conds = self.encode_prompt(prompts, device, 1, True)[0]
+ unconds = (
+ self.encode_prompt(n_prompts, device, 1, True)[0]
+ if equal
+ else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
+ )
+ embs = n_embs = None
+
+ if not active:
+ pcallback = None
+ mode = None
+ else:
+ if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
+ mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
+ ocells, icells, regions = make_cells(rp_args["div"])
+
+ elif "PRO" in rp_args["mode"].upper():
+ regions = len(all_prompts_p[0])
+ mode = "PROMPT"
+ reset_attnmaps(self)
+ self.ex = "EX" in rp_args["mode"].upper()
+ self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
+ thresholds = [float(x) for x in rp_args["th"].split(",")]
+
+ orig_hw = (height, width)
+ revers = True
+
+ def pcallback(s_self, step: int, timestep: int, latents: torch.Tensor, selfs=None):
+ if "PRO" in mode:  # in Prompt mode, make masks from the sum of attention maps
+ self.step = step
+
+ if len(self.attnmaps_sizes) > 3:
+ self.history[step] = self.attnmaps.copy()
+ for hw in self.attnmaps_sizes:
+ allmasks = []
+ basemasks = [None] * batch
+ for tt, th in zip(target_tokens, thresholds):
+ for b in range(batch):
+ key = f"{tt}-{b}"
+ _, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
+ mask = mask.unsqueeze(0).unsqueeze(-1)
+ if self.ex:
+ allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
+ allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
+ allmasks.append(mask)
+ basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
+ basemasks = [1 - mask for mask in basemasks]
+ basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
+ allmasks = basemasks + allmasks
+
+ self.attnmasks[hw] = torch.cat(allmasks)
+ self.maskready = True
+ return latents
+
+ def hook_forward(module):
+ # diffusers==0.23.2
+ def forward(
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ scale: float = 1.0,
+ ) -> torch.Tensor:
+ attn = module
+ xshape = hidden_states.shape
+ self.hw = (h, w) = split_dims(xshape[1], *orig_hw)
+
+ if revers:
+ nx, px = hidden_states.chunk(2)
+ else:
+ px, nx = hidden_states.chunk(2)
+
+ if equal:
+ hidden_states = torch.cat(
+ [px for i in range(regions)] + [nx for i in range(regions)],
+ 0,
+ )
+ encoder_hidden_states = torch.cat([conds] + [unconds])
+ else:
+ hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
+ encoder_hidden_states = torch.cat([conds] + [unconds])
+
+ residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+ query = attn.to_q(hidden_states, *args)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = scaled_dot_product_attention(
+ self,
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ getattn="PRO" in mode,
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states, *args)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ #### Regional Prompting Col/Row mode
+ if any(x in mode for x in ["COL", "ROW"]):
+ reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
+ center = reshaped.shape[0] // 2
+ px = reshaped[0:center] if equal else reshaped[0:-batch]
+ nx = reshaped[center:] if equal else reshaped[-batch:]
+ outs = [px, nx] if equal else [px]
+ for out in outs:
+ c = 0
+ for i, ocell in enumerate(ocells):
+ for icell in icells[i]:
+ if "ROW" in mode:
+ out[
+ 0:batch,
+ int(h * ocell[0]) : int(h * ocell[1]),
+ int(w * icell[0]) : int(w * icell[1]),
+ :,
+ ] = out[
+ c * batch : (c + 1) * batch,
+ int(h * ocell[0]) : int(h * ocell[1]),
+ int(w * icell[0]) : int(w * icell[1]),
+ :,
+ ]
+ else:
+ out[
+ 0:batch,
+ int(h * icell[0]) : int(h * icell[1]),
+ int(w * ocell[0]) : int(w * ocell[1]),
+ :,
+ ] = out[
+ c * batch : (c + 1) * batch,
+ int(h * icell[0]) : int(h * icell[1]),
+ int(w * ocell[0]) : int(w * ocell[1]),
+ :,
+ ]
+ c += 1
+ px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
+ hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
+ hidden_states = hidden_states.reshape(xshape)
+
+ #### Regional Prompting Prompt mode
+ elif "PRO" in mode:
+ px, nx = (
+ hidden_states.chunk(2)[0] if equal else hidden_states[0:-batch],
+ hidden_states.chunk(2)[1] if equal else hidden_states[-batch:],
+ )
+
+ if (h, w) in self.attnmasks and self.maskready:
+
+ def mask(input):
+ out = torch.multiply(input, self.attnmasks[(h, w)])
+ for b in range(batch):
+ for r in range(1, regions):
+ out[b] = out[b] + out[r * batch + b]
+ return out
+
+ px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
+ px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
+ hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
+ return hidden_states
+
+ return forward
+
+ def hook_forwards(root_module: torch.nn.Module):
+ for name, module in root_module.named_modules():
+ if "attn2" in name and module.__class__.__name__ == "Attention":
+ module.forward = hook_forward(module)
+
+ hook_forwards(self.unet)
+
+ output = StableDiffusionPipeline(**self.components)(
+ prompt=prompt,
+ prompt_embeds=embs,
+ negative_prompt=negative_prompt,
+ negative_prompt_embeds=n_embs,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback_on_step_end=pcallback,
+ )
+
+ if "save_mask" in rp_args:
+ save_mask = rp_args["save_mask"]
+ else:
+ save_mask = False
+
+ if mode == "PROMPT" and save_mask:
+ saveattnmaps(
+ self,
+ output,
+ height,
+ width,
+ thresholds,
+ num_inference_steps // 2,
+ regions,
+ )
+
+ return output
+
+
+### Make a prompt list for each region
+def promptsmaker(prompts, batch):
+ out_p = []
+ plen = len(prompts)
+ for prompt in prompts:
+ add = ""
+ if KCOMM in prompt:
+ add, prompt = prompt.split(KCOMM)
+ add = add + " "
+ prompts = prompt.split(KBRK)
+ out_p.append([add + p for p in prompts])
+ out = [None] * batch * len(out_p[0]) * len(out_p)
+ for p, prs in enumerate(out_p): # inputs prompts
+ for r, pr in enumerate(prs): # prompts for regions
+ start = (p + r * plen) * batch
+ out[start : start + batch] = [pr] * batch # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
+ return out, out_p
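+# Example for promptsmaker above (illustrative): with batch=1, the single prompt
+# "sky ADDCOMM sun BREAK moon" is split into two region prompts, roughly "sky sun" and "sky moon"
+# (the ADDCOMM part is prepended to every region). `out` is the flat list of region prompts
+# repeated per batch item, and `out_p` keeps them grouped per input prompt.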
+
+
+### make regions from ratios
+### ";" makes outer cells, "," makes inner cells
+def make_cells(ratios):
+ if ";" not in ratios and "," in ratios:
+ ratios = ratios.replace(",", ";")
+ ratios = ratios.split(";")
+ ratios = [inratios.split(",") for inratios in ratios]
+
+ icells = []
+ ocells = []
+
+ def startend(cells, array):
+ current_start = 0
+ array = [float(x) for x in array]
+ for value in array:
+ end = current_start + (value / sum(array))
+ cells.append([current_start, end])
+ current_start = end
+
+ startend(ocells, [r[0] for r in ratios])
+
+ for inratios in ratios:
+ if 2 > len(inratios):
+ icells.append([[0, 1]])
+ else:
+ add = []
+ startend(add, inratios[1:])
+ icells.append(add)
+
+ return ocells, icells, sum(len(cell) for cell in icells)
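+# Example for make_cells above (illustrative): make_cells("1;1;1") yields three equal outer cells
+# [[0, 1/3], [1/3, 2/3], [2/3, 1]], each with a single full-span inner cell [[0, 1]], and 3 regions
+# in total. In a ratio such as "1;1,2,3" the first number of each ';'-separated group sets the outer
+# split and the remaining ','-separated numbers split that cell internally (here into 2:3).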
+
+
+def make_emblist(self, prompts):
+ with torch.no_grad():
+ tokens = self.tokenizer(
+ prompts,
+ max_length=self.tokenizer.model_max_length,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids.to(self.device)
+ embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
+ return embs
+
+
+def split_dims(xs, height, width):
+ def repeat_div(x, y):
+ while y > 0:
+ x = math.ceil(x / 2)
+ y = y - 1
+ return x
+
+ scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
+ dsh = repeat_div(height, scale)
+ dsw = repeat_div(width, scale)
+ return dsh, dsw
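+# Example for split_dims above (illustrative): for a 512x512 image, an attention layer with
+# sequence length xs = 4096 corresponds to an 8x downscale (scale = ceil(log2(sqrt(512*512/4096))) = 3),
+# so split_dims(4096, 512, 512) == (64, 64).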
+
+
+##### for prompt mode
+def get_attn_maps(self, attn):
+ height, width = self.hw
+ target_tokens = self.target_tokens
+ if (height, width) not in self.attnmaps_sizes:
+ self.attnmaps_sizes.append((height, width))
+
+ for b in range(self.batch):
+ for t in target_tokens:
+ power = self.power
+ add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
+ add = torch.sum(add, dim=2)
+ key = f"{t}-{b}"
+ if key not in self.attnmaps:
+ self.attnmaps[key] = add
+ else:
+ if self.attnmaps[key].shape[1] != add.shape[1]:
+ add = add.view(8, height, width)
+ add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
+ add = add.reshape_as(self.attnmaps[key])
+
+ self.attnmaps[key] = self.attnmaps[key] + add
+
+
+def reset_attnmaps(self): # init parameters in every batch
+ self.step = 0
+ self.attnmaps = {} # made from attention maps
+ self.attnmaps_sizes = [] # height,width set of u-net blocks
+ self.attnmasks = {} # made from attnmaps for regions
+ self.maskready = False
+ self.history = {}
+
+
+def saveattnmaps(self, output, h, w, th, step, regions):
+ masks = []
+ for i, mask in enumerate(self.history[step].values()):
+ img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
+ if self.ex:
+ masks = [x - mask for x in masks]
+ masks.append(mask)
+ if len(masks) == regions - 1:
+ output.images.extend([FF.to_pil_image(mask) for mask in masks])
+ masks = []
+ else:
+ output.images.append(img)
+
+
+def makepmask(
+ self, mask, h, w, th, step
+ ): # make masks from the attention cache; returns [preview image, attention mask, latent mask]
+ th = th - step * 0.005
+ if 0.05 >= th:
+ th = 0.05
+ mask = torch.mean(mask, dim=0)
+ mask = mask / mask.max().item()
+ mask = torch.where(mask > th, 1, 0)
+ mask = mask.float()
+ mask = mask.view(1, *self.attnmaps_sizes[0])
+ img = FF.to_pil_image(mask)
+ img = img.resize((w, h))
+ mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
+ lmask = mask
+ mask = mask.reshape(h * w)
+ mask = torch.where(mask > 0.1, 1, 0)
+ return img, mask, lmask
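+# Note on makepmask above: the threshold decays linearly with the denoising step
+# (th - 0.005 * step, floored at 0.05), so the masks grow more permissive as sampling proceeds;
+# e.g. th=0.5 at step 40 gives an effective threshold of 0.3.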
+
+
+def tokendealer(self, all_prompts):
+ for prompts in all_prompts:
+ targets = [p.split(",")[-1] for p in prompts[1:]]
+ tt = []
+
+ for target in targets:
+ ptokens = (
+ self.tokenizer(
+ prompts,
+ max_length=self.tokenizer.model_max_length,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+ )[0]
+ ttokens = (
+ self.tokenizer(
+ target,
+ max_length=self.tokenizer.model_max_length,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+ )[0]
+
+ tlist = []
+
+ for t in range(ttokens.shape[0] - 2):
+ for p in range(ptokens.shape[0]):
+ if ttokens[t + 1] == ptokens[p]:
+ tlist.append(p)
+ if tlist != []:
+ tt.append(tlist)
+
+ return tt
+
+
+def scaled_dot_product_attention(
+ self,
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ scale=None,
+ getattn=False,
+) -> torch.Tensor:
+ # Reference (non-fused) implementation of scaled dot-product attention, kept explicit so the
+ # attention weights can be captured below for prompt mode.
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
+ if is_causal:
+ assert attn_mask is None
+ temp_mask = torch.ones(L, S, dtype=torch.bool, device=self.device).tril(diagonal=0)
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+ attn_bias = attn_bias.to(query.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ if getattn:
+ get_attn_maps(self, attn_weight)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+ return attn_weight @ value
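+# Shape sketch for scaled_dot_product_attention above (illustrative): with query/key/value of shape
+# (B, heads, L, head_dim) / (B, heads, S, head_dim), `attn_weight` has shape (B, heads, L, S) and the
+# returned tensor has shape (B, heads, L, head_dim). Unlike torch.nn.functional.scaled_dot_product_attention,
+# the explicit softmax here lets get_attn_maps capture the per-token attention weights in prompt mode.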
diff --git a/diffusers/examples/community/rerender_a_video.py b/diffusers/examples/community/rerender_a_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c616ab5ebc9aa6c504e809e21c5c83f87b57b2
--- /dev/null
+++ b/diffusers/examples/community/rerender_a_video.py
@@ -0,0 +1,1194 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+from gmflow.gmflow import GMFlow
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from diffusers.models.attention_processor import Attention, AttnProcessor
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import BaseOutput, deprecate, logging
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def coords_grid(b, h, w, homogeneous=False, device=None):
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
+
+ stacks = [x, y]
+
+ if homogeneous:
+ ones = torch.ones_like(x) # [H, W]
+ stacks.append(ones)
+
+ grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
+
+ grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
+
+ if device is not None:
+ grid = grid.to(device)
+
+ return grid
+
+
+def bilinear_sample(img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False):
+ # img: [B, C, H, W]
+ # sample_coords: [B, 2, H, W] in image scale
+ if sample_coords.size(1) != 2: # [B, H, W, 2]
+ sample_coords = sample_coords.permute(0, 3, 1, 2)
+
+ b, _, h, w = sample_coords.shape
+
+ # Normalize to [-1, 1]
+ x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
+ y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
+
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
+
+ img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
+
+ if return_mask:
+ mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
+
+ return img, mask
+
+ return img
+
+
+def flow_warp(feature, flow, mask=False, mode="bilinear", padding_mode="zeros"):
+ b, c, h, w = feature.size()
+ assert flow.size(1) == 2
+
+ grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
+ grid = grid.to(feature.dtype)
+ return bilinear_sample(feature, grid, mode=mode, padding_mode=padding_mode, return_mask=mask)
+
+
+def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
+ # fwd_flow, bwd_flow: [B, 2, H, W]
+ # alpha and beta values are following UnFlow
+ # (https://arxiv.org/abs/1711.07837)
+ assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
+ assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
+ flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
+
+ warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
+ warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
+
+ diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
+ diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
+
+ threshold = alpha * flow_mag + beta
+
+ fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
+ bwd_occ = (diff_bwd > threshold).float()
+
+ return fwd_occ, bwd_occ
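+# Consistency-check sketch for the function above: a pixel p is flagged as occluded in the forward
+# direction when |f_fwd(p) + f_bwd(p + f_fwd(p))| > alpha * (|f_fwd(p)| + |f_bwd(p)|) + beta,
+# i.e. when the forward flow and the back-warped backward flow do not cancel out.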
+
+
+@torch.no_grad()
+def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False, device=None):
+ if image3 is None:
+ image3 = image1
+ padder = InputPadder(image1.shape, padding_factor=8)
+ image1, image2 = padder.pad(image1[None].to(device), image2[None].to(device))
+ results_dict = flow_model(
+ image1, image2, attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True
+ )
+ flow_pr = results_dict["flow_preds"][-1] # [B, 2, H, W]
+ fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W]
+ bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W]
+ fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow) # [1, H, W] float
+ if pixel_consistency:
+ warped_image1 = flow_warp(image1, bwd_flow)
+ bwd_occ = torch.clamp(
+ bwd_occ + (abs(image2 - warped_image1).mean(dim=1) > 255 * 0.25).float(), 0, 1
+ ).unsqueeze(0)
+ warped_results = flow_warp(image3, bwd_flow)
+ return warped_results, bwd_occ, bwd_flow
+
+
+blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
+
+
+@dataclass
+class TextToVideoSDPipelineOutput(BaseOutput):
+ """
+ Output class for text-to-video pipelines.
+
+ Args:
+ frames (`List[np.ndarray]` or `torch.Tensor`)
+ List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
+ a `torch` tensor. The length of the list denotes the video length (the number of frames).
+ """
+
+ frames: Union[List[np.ndarray], torch.Tensor]
+
+
+@torch.no_grad()
+def find_flat_region(mask):
+ device = mask.device
+ kernel_x = torch.Tensor([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).to(device)
+ kernel_y = torch.Tensor([[-1, -1, -1], [0, 0, 0], [1, 1, 1]]).unsqueeze(0).unsqueeze(0).to(device)
+ mask_ = F.pad(mask.unsqueeze(0), (1, 1, 1, 1), mode="replicate")
+
+ grad_x = torch.nn.functional.conv2d(mask_, kernel_x)
+ grad_y = torch.nn.functional.conv2d(mask_, kernel_y)
+ return ((abs(grad_x) + abs(grad_y)) == 0).float()[0]
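+# Note on find_flat_region above: kernel_x / kernel_y are Prewitt-style gradient filters, so the
+# returned map is 1.0 wherever the (replicate-padded) mask has zero horizontal and vertical gradient,
+# i.e. inside flat regions, and 0.0 along edges.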
+
+
+class AttnState:
+ STORE = 0
+ LOAD = 1
+ LOAD_AND_STORE_PREV = 2
+
+ def __init__(self):
+ self.reset()
+
+ @property
+ def state(self):
+ return self.__state
+
+ @property
+ def timestep(self):
+ return self.__timestep
+
+ def set_timestep(self, t):
+ self.__timestep = t
+
+ def reset(self):
+ self.__state = AttnState.STORE
+ self.__timestep = 0
+
+ def to_load(self):
+ self.__state = AttnState.LOAD
+
+ def to_load_and_store_prev(self):
+ self.__state = AttnState.LOAD_AND_STORE_PREV
+
+
+class CrossFrameAttnProcessor(AttnProcessor):
+ """
+ Cross frame attention processor. Each frame attends to the first frame and the previous frame.
+
+ Args:
+ attn_state: Shared state tracking whether the model is processing the first frame or an intermediate frame
+ """
+
+ def __init__(self, attn_state: AttnState):
+ super().__init__()
+ self.attn_state = attn_state
+ self.first_maps = {}
+ self.prev_maps = {}
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+ # Is self attention
+ if encoder_hidden_states is None:
+ t = self.attn_state.timestep
+ if self.attn_state.state == AttnState.STORE:
+ self.first_maps[t] = hidden_states.detach()
+ self.prev_maps[t] = hidden_states.detach()
+ res = super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
+ else:
+ if self.attn_state.state == AttnState.LOAD_AND_STORE_PREV:
+ tmp = hidden_states.detach()
+ cross_map = torch.cat((self.first_maps[t], self.prev_maps[t]), dim=1)
+ res = super().__call__(attn, hidden_states, cross_map, attention_mask, temb)
+ if self.attn_state.state == AttnState.LOAD_AND_STORE_PREV:
+ self.prev_maps[t] = tmp
+ else:
+ res = super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
+
+ return res
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
+
+class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
+ r"""
+ Pipeline for video-to-video translation using Stable Diffusion with Rerender Algorithm.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+ as a list, the outputs from each ControlNet are added together to create one combined additional
+ conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder=None,
+ requires_safety_checker: bool = True,
+ device=None,
+ ):
+ super().__init__(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ controlnet,
+ scheduler,
+ safety_checker,
+ feature_extractor,
+ image_encoder,
+ requires_safety_checker,
+ )
+ self.to(device)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+ self.attn_state = AttnState()
+ attn_processor_dict = {}
+ for k in unet.attn_processors.keys():
+ if k.startswith("up"):
+ attn_processor_dict[k] = CrossFrameAttnProcessor(self.attn_state)
+ else:
+ attn_processor_dict[k] = AttnProcessor()
+
+ self.unet.set_attn_processor(attn_processor_dict)
+
+ flow_model = GMFlow(
+ feature_channels=128,
+ num_scales=1,
+ upsample_factor=8,
+ num_head=1,
+ attention_type="swin",
+ ffn_dim_expansion=4,
+ num_transformer_layers=6,
+ ).to(self.device)
+
+ checkpoint = torch.utils.model_zoo.load_url(
+ "https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth",
+ map_location=lambda storage, loc: storage,
+ )
+ weights = checkpoint["model"] if "model" in checkpoint else checkpoint
+ flow_model.load_state_dict(weights, strict=False)
+ flow_model.eval()
+ self.flow_model = flow_model
+
+ # Modified from src/diffusers/pipelines/controlnet/pipeline_controlnet.StableDiffusionControlNetImg2ImgPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
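+ # With SDEdit, only the final `strength` fraction of the schedule is run: `t_start` skips the earlier steps,
+ # so `strength=1.0` denoises over the full schedule and smaller values keep more of the input image.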
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ frames: Union[List[np.ndarray], torch.Tensor] = None,
+ control_frames: Union[List[np.ndarray], torch.Tensor] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ warp_start: Union[float, List[float]] = 0.0,
+ warp_end: Union[float, List[float]] = 0.3,
+ mask_start: Union[float, List[float]] = 0.5,
+ mask_end: Union[float, List[float]] = 0.8,
+ smooth_boundary: bool = True,
+ mask_strength: Union[float, List[float]] = 0.5,
+ inner_strength: Union[float, List[float]] = 0.9,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.
+ control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.
+ strength (`float`, *optional*, defaults to 0.8): SDEdit strength.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list. Note that the default used here is smaller than the 1.0 used in
+ [`~StableDiffusionControlNetPipeline.__call__`].
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the controlnet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the controlnet stops applying.
+ warp_start (`float`): Shape-aware fusion start point, as a fraction of the denoising steps.
+ warp_end (`float`): Shape-aware fusion end point, as a fraction of the denoising steps.
+ mask_start (`float`): Pixel-aware fusion start point, as a fraction of the denoising steps.
+ mask_end (`float`): Pixel-aware fusion end point, as a fraction of the denoising steps.
+ smooth_boundary (`bool`): Whether to smooth the fusion boundary. Set `True` to prevent artifacts at the boundary.
+ mask_strength (`float`): Pixel-aware fusion strength.
+ inner_strength (`float`): Pixel-aware fusion detail level.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `list`:
+ [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a
+ plain `list` containing the generated output frames.
+ """
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ # Currently we only support 1 prompt
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ raise ValueError("Only a single `prompt` string is currently supported.")
+ else:
+ raise ValueError("A `prompt` string must be provided.")
+ num_images_per_prompt = 1
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Process the first frame
+ height, width = None, None
+ output_frames = []
+ self.attn_state.reset()
+
+ # 4.1 prepare frames
+ image = self.image_processor.preprocess(frames[0]).to(dtype=torch.float32)
+ first_image = image[0] # C, H, W
+
+ # 4.2 Prepare controlnet_conditioning_image
+ # Currently we only support single control
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_frames[0],
+ width=width,
+ height=height,
+ batch_size=batch_size,
+ num_images_per_prompt=1,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ else:
+ assert False
+
+ # 4.3 Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, cur_num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+
+ # 4.4 Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # 4.5 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 4.6 Create tensor stating which controlnets to keep
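+ # Each entry is 1.0 while the current step fraction lies inside [control_guidance_start, control_guidance_end]
+ # and 0.0 outside, so ControlNet guidance can be limited to a sub-range of the denoising schedule.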
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ first_x0_list = []
+
+ # 4.7 Denoising loop
+ num_warmup_steps = len(timesteps) - cur_num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=cur_num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ self.attn_state.set_timestep(t.item())
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
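+ # Estimate x_0 from x_t via pred_x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t) and cache it;
+ # later frames warp these first-frame estimates with optical flow for shape-aware fusion.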
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
+ beta_prod_t = 1 - alpha_prod_t
+ pred_x0 = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+ first_x0 = pred_x0.detach()
+ first_x0_list.append(first_x0)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ first_result = image
+ prev_result = image
+ do_denormalize = [True] * image.shape[0]
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ output_frames.append(image[0])
+
+ # 5. Process each frame
+ for idx in range(1, len(frames)):
+ image = frames[idx]
+ prev_image = frames[idx - 1]
+ control_image = control_frames[idx]
+ # 5.1 prepare frames
+ image = self.image_processor.preprocess(image).to(dtype=torch.float32)
+ prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
+
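+ # Estimate optical flow from the first and previous frames to the current frame, warp their generated
+ # results accordingly, and build blending masks from the (blurred, dilated) backward occlusion maps.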
+ warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
+ self.flow_model, first_image, image[0], first_result, False, self.device
+ )
+ blend_mask_0 = blur(F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
+ blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
+
+ warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
+ self.flow_model, prev_image[0], image[0], prev_result, False, self.device
+ )
+ blend_mask_pre = blur(F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
+ blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)
+
+ warp_mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)
+ warp_flow = F.interpolate(bwd_flow_0 / 8.0, scale_factor=1.0 / 8, mode="bilinear")
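+ # Downsample the warp mask and flow by the VAE factor of 8 so they can be applied directly in latent space.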
+
+ # 5.2 Prepare controlnet_conditioning_image
+ # Currently we only support single control
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size,
+ num_images_per_prompt=1,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ else:
+ assert False
+
+ # 5.3 Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, cur_num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+
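+ # Convert the fractional fusion windows to step indices on the full schedule; `skip_t` is the number of
+ # steps skipped by SDEdit (`strength`), so `i + skip_t` below indexes the full `num_inference_steps` range.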
+ skip_t = int(num_inference_steps * (1 - strength))
+ warp_start_t = int(warp_start * num_inference_steps)
+ warp_end_t = int(warp_end * num_inference_steps)
+ mask_start_t = int(mask_start * num_inference_steps)
+ mask_end_t = int(mask_end * num_inference_steps)
+
+ # 5.4 Prepare latent variables
+ init_latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # 5.5 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 5.6 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ # 5.7 Denoising loop
+ num_warmup_steps = len(timesteps) - cur_num_inference_steps * self.scheduler.order
+
+ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None):
+ dir_xt = 0
+ latents_dtype = latents.dtype
+ with self.progress_bar(total=cur_num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ self.attn_state.set_timestep(t.item())
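+ # Pixel-aware fusion: inside [mask_start_t, mask_end_t] and when `xtrg` is provided, re-noise the blended
+ # reference latents to the current timestep and mix them into the sample according to `mask`, rescaling the
+ # DDIM direction term `dir_xt` by `inner_strength`/`noise_rescale` to control how much detail is kept.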
+ if i + skip_t >= mask_start_t and i + skip_t <= mask_end_t and xtrg is not None:
+ rescale = torch.maximum(1.0 - mask, (1 - mask**2) ** 0.5 * inner_strength)
+ if noise_rescale is not None:
+ rescale = (1.0 - mask) * (1 - noise_rescale) + rescale * noise_rescale
+ noise = randn_tensor(xtrg.shape, generator=generator, device=device, dtype=xtrg.dtype)
+ latents_ref = self.scheduler.add_noise(xtrg, noise, t)
+ latents = latents_ref * mask + (1.0 - mask) * (latents - dir_xt) + rescale * dir_xt
+ latents = latents.to(latents_dtype)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [
+ torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.cat(
+ [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # Get pred_x0 from scheduler
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
+ beta_prod_t = 1 - alpha_prod_t
+ pred_x0 = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
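+ # Shape-aware fusion: inside [warp_start_t, warp_end_t], replace non-occluded regions of the current x_0
+ # estimate with the first frame's x_0 prediction warped by optical flow to enforce temporal consistency.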
+ if i + skip_t >= warp_start_t and i + skip_t <= warp_end_t:
+ # warp x_0
+ pred_x0 = (
+ flow_warp(first_x0_list[i], warp_flow, mode="nearest") * warp_mask
+ + (1 - warp_mask) * pred_x0
+ )
+
+ # get x_t from x_0
+ latents = self.scheduler.add_noise(pred_x0, noise_pred, t).to(latents_dtype)
+
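+ # DDIM direction term dir_xt = sqrt(1 - alpha_bar_prev) * eps, consumed by the pixel-aware fusion above
+ # on the next iteration.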
+ prev_t = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+ if i == len(timesteps) - 1:
+ alpha_t_prev = 1.0
+ else:
+ alpha_t_prev = self.scheduler.alphas_cumprod[prev_t]
+
+ dir_xt = (1.0 - alpha_t_prev) ** 0.5 * noise_pred
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[
+ 0
+ ]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ return latents
+
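+ # Depending on whether pixel-aware fusion is enabled (a non-empty mask window), the first pass either only
+ # loads the cached cross-frame attention features or also stores them for the next frame; when fusion is
+ # enabled, the fused second pass below handles storing.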
+ if mask_start_t <= mask_end_t:
+ self.attn_state.to_load()
+ else:
+ self.attn_state.to_load_and_store_prev()
+ latents = denoising_loop(init_latents)
+
+ if mask_start_t <= mask_end_t:
+ direct_result = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ blend_results = (1 - blend_mask_pre) * warped_pre + blend_mask_pre * direct_result
+ blend_results = (1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results
+
+ bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
+ blend_mask = blur(F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
+ blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)
+
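+ # Fidelity-oriented encoding: measure the VAE round-trip error (xtrg - xtrg_rec) and add it back so that
+ # decoding xtrg reproduces the blended image more closely; regions where the corrected reconstruction still
+ # deviates strongly (tmp / mask_x) are excluded from the correction.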
+ blend_results = blend_results.to(latents.dtype)
+ xtrg = self.vae.encode(blend_results).latent_dist.sample(generator)
+ xtrg = self.vae.config.scaling_factor * xtrg
+ blend_results_rec = self.vae.decode(xtrg / self.vae.config.scaling_factor, return_dict=False)[0]
+ xtrg_rec = self.vae.encode(blend_results_rec).latent_dist.sample(generator)
+ xtrg_rec = self.vae.config.scaling_factor * xtrg_rec
+ xtrg_ = xtrg + (xtrg - xtrg_rec)
+ blend_results_rec_new = self.vae.decode(xtrg_ / self.vae.config.scaling_factor, return_dict=False)[0]
+ tmp = (abs(blend_results_rec_new - blend_results).mean(dim=1, keepdims=True) > 0.25).float()
+
+ mask_x = F.max_pool2d(
+ (F.interpolate(tmp, scale_factor=1 / 8.0, mode="bilinear") > 0).float(),
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ mask = 1 - F.max_pool2d(1 - blend_mask, kernel_size=8) # * (1-mask_x)
+
+ if smooth_boundary:
+ noise_rescale = find_flat_region(mask)
+ else:
+ noise_rescale = torch.ones_like(mask)
+
+ xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask
+ xtrg = xtrg.to(latents.dtype)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, cur_num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ self.attn_state.to_load_and_store_prev()
+ latents = denoising_loop(init_latents, mask * mask_strength, xtrg, noise_rescale)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ prev_result = image
+
+ do_denormalize = [True] * image.shape[0]
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ output_frames.append(image[0])
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return output_frames
+
+ return TextToVideoSDPipelineOutput(frames=output_frames)
+
+
+class InputPadder:
+ """Pads images such that dimensions are divisible by 8"""
+
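+ # The pad amounts round height/width up to the next multiple of `padding_factor`; "sintel" mode pads
+ # symmetrically on all four sides, other modes pad the width symmetrically and only the bottom in height.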
+ def __init__(self, dims, mode="sintel", padding_factor=8):
+ self.ht, self.wd = dims[-2:]
+ pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor
+ pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor
+ if mode == "sintel":
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
+ else:
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
+
+ def pad(self, *inputs):
+ return [F.pad(x, self._pad, mode="replicate") for x in inputs]
+
+ def unpad(self, x):
+ ht, wd = x.shape[-2:]
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+ return x[..., c[0] : c[1], c[2] : c[3]]
diff --git a/diffusers/examples/community/run_onnx_controlnet.py b/diffusers/examples/community/run_onnx_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..af2672c17e594ab1b03c63cdc16ffbefef947a1c
--- /dev/null
+++ b/diffusers/examples/community/run_onnx_controlnet.py
@@ -0,0 +1,911 @@
+import argparse
+import inspect
+import os
+import time
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from PIL import Image
+from transformers import CLIPTokenizer
+
+from diffusers import OnnxRuntimeModel, StableDiffusionImg2ImgPipeline, UniPCMultistepScheduler
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+ ... )
+ >>> np_image = np.array(image)
+
+ >>> # get canny image
+ >>> np_image = cv2.Canny(np_image, 100, 200)
+ >>> np_image = np_image[:, :, None]
+ >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
+ >>> canny_image = Image.fromarray(np_image)
+
+ >>> # load control net and stable diffusion v1-5
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+
+ >>> # speed up diffusion process with faster scheduler and memory optimization
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> generator = torch.manual_seed(0)
+ >>> image = pipe(
+ ... "futuristic-looking woman",
+ ... num_inference_steps=20,
+ ... generator=generator,
+ ... image=image,
+ ... control_image=canny_image,
+ ... ).images[0]
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
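+ # NHWC uint8 arrays -> NCHW float tensors normalized to [-1, 1] (pixel / 127.5 - 1).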
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
+
+class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: KarrasDiffusionSchedulers
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
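+ # The Stable Diffusion VAE downsamples by a factor of 2 ** (4 - 1) = 8, so latents are 8x smaller than pixels.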
+ self.vae_scale_factor = 2 ** (4 - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: Optional[int],
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str],
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+
+ prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ num_controlnet,
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Check `image`
+ if num_controlnet == 1:
+ self.check_image(image, prompt, prompt_embeds)
+ elif num_controlnet > 1:
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(image) != num_controlnet:
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {num_controlnet} ControlNets."
+ )
+
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if num_controlnet == 1:
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif num_controlnet > 1:
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif (
+ isinstance(controlnet_conditioning_scale, list)
+ and len(controlnet_conditioning_scale) != num_controlnet
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if num_controlnet > 1:
+ if len(control_guidance_start) != num_controlnet:
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_controlnet} controlnets available. Make sure to provide {num_controlnet}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
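+ # Encode pixels with the ONNX VAE encoder (numpy in/out) and apply the standard SD latent scaling factor 0.18215.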
+ _image = image.cpu().detach().numpy()
+ init_latents = self.vae_encoder(sample=_image)[0]
+ init_latents = torch.from_numpy(init_latents).to(device=device, dtype=dtype)
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ num_controlnet: int,
+ fp16: bool = True,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.Tensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.Tensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ control_image: Union[
+ torch.Tensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.Tensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The initial image will be used as the starting point for the image generation process. Can also accept
+ image latents as `image`; if latents are passed directly, they will not be encoded again.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+ specified in init, images must be passed as a list such that each element of the list can be correctly
+ batched for input to a single controlnet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list. Note that the default used here is smaller than the 1.0 used in
+ [`~StableDiffusionControlNetPipeline.__call__`].
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the controlnet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the controlnet stops applying.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if fp16:
+ torch_dtype = torch.float16
+ np_dtype = np.float16
+ else:
+ torch_dtype = torch.float32
+ np_dtype = np.float32
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = num_controlnet
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ num_controlnet,
+ prompt,
+ control_image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if num_controlnet > 1 and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnet
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ # 4. Prepare image
+ image = self.image_processor.preprocess(image).to(dtype=torch.float32)
+
+ # 5. Prepare controlnet_conditioning_image
+ if num_controlnet == 1:
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=torch_dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif num_controlnet > 1:
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=torch_dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ assert False
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 7. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ torch_dtype,
+ device,
+ generator,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8.1 Create tensor stating which controlnets to keep
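+ # Each entry of `controlnet_keep` is 1.0 when the current step fraction lies inside
+ # [control_guidance_start, control_guidance_end] for that ControlNet and 0.0 otherwise. For example,
+ # with control_guidance_start=0.3 and control_guidance_end=0.9 (the values used in __main__ below),
+ # the ControlNet residuals are skipped for roughly the first 30% and the last 10% of the steps.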
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if num_controlnet == 1 else keeps)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # predict the noise residual
+ _latent_model_input = latent_model_input.cpu().detach().numpy()
+ _prompt_embeds = np.array(prompt_embeds, dtype=np_dtype)
+ _t = np.array([t.cpu().detach().numpy()], dtype=np_dtype)
+
+ if num_controlnet == 1:
+ control_images = np.array([control_image], dtype=np_dtype)
+ else:
+ control_images = []
+ for _control_img in control_image:
+ _control_img = _control_img.cpu().detach().numpy()
+ control_images.append(_control_img)
+ control_images = np.array(control_images, dtype=np_dtype)
+
+ control_scales = np.array(cond_scale, dtype=np_dtype)
+ control_scales = np.resize(control_scales, (num_controlnet, 1))
+
+ noise_pred = self.unet(
+ sample=_latent_model_input,
+ timestep=_t,
+ encoder_hidden_states=_prompt_embeds,
+ controlnet_conds=control_images,
+ conditioning_scales=control_scales,
+ )[0]
+ noise_pred = torch.from_numpy(noise_pred).to(device)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ _latents = latents.cpu().detach().numpy() / 0.18215
+ _latents = np.array(_latents, dtype=np_dtype)
+ image = self.vae_decoder(latent_sample=_latents)[0]
+ image = torch.from_numpy(image).to(device, dtype=torch.float32)
+ has_nsfw_concept = None
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--sd_model",
+ type=str,
+ required=True,
+ help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
+ )
+
+ parser.add_argument(
+ "--onnx_model_dir",
+ type=str,
+ required=True,
+ help="Path to the ONNX directory",
+ )
+
+ parser.add_argument("--qr_img_path", type=str, required=True, help="Path to the qr code image")
+
+ args = parser.parse_args()
+
+ qr_image = Image.open(args.qr_img_path)
+ qr_image = qr_image.resize((512, 512))
+
+ # init stable diffusion pipeline
+ pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(args.sd_model)
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+
+ provider = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ onnx_pipeline = OnnxStableDiffusionControlNetImg2ImgPipeline(
+ vae_encoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "vae_encoder"), provider=provider
+ ),
+ vae_decoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "vae_decoder"), provider=provider
+ ),
+ text_encoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "text_encoder"), provider=provider
+ ),
+ tokenizer=pipeline.tokenizer,
+ unet=OnnxRuntimeModel.from_pretrained(os.path.join(args.onnx_model_dir, "unet"), provider=provider),
+ scheduler=pipeline.scheduler,
+ )
+ onnx_pipeline = onnx_pipeline.to("cuda")
+
+ prompt = "a cute cat fly to the moon"
+ negative_prompt = "paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect"
+
+ for i in range(10):
+ start_time = time.time()
+ image = onnx_pipeline(
+ num_controlnet=2,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=qr_image,
+ control_image=[qr_image, qr_image],
+ width=512,
+ height=512,
+ strength=0.75,
+ num_inference_steps=20,
+ num_images_per_prompt=1,
+ controlnet_conditioning_scale=[0.8, 0.8],
+ control_guidance_start=[0.3, 0.3],
+ control_guidance_end=[0.9, 0.9],
+ ).images[0]
+ print(time.time() - start_time)
+ image.save("output_qr_code.png")
diff --git a/diffusers/examples/community/run_tensorrt_controlnet.py b/diffusers/examples/community/run_tensorrt_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..873195fa31bae1d8272cdac7c18f6bd12e8b782f
--- /dev/null
+++ b/diffusers/examples/community/run_tensorrt_controlnet.py
@@ -0,0 +1,1022 @@
+import argparse
+import atexit
+import inspect
+import os
+import time
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import pycuda.driver as cuda
+import tensorrt as trt
+import torch
+from PIL import Image
+from pycuda.tools import make_default_context
+from transformers import CLIPTokenizer
+
+from diffusers import OnnxRuntimeModel, StableDiffusionImg2ImgPipeline, UniPCMultistepScheduler
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+# Initialize CUDA
+cuda.init()
+context = make_default_context()
+device = context.get_device()
+atexit.register(context.pop)
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def load_engine(trt_runtime, engine_path):
+ with open(engine_path, "rb") as f:
+ engine_data = f.read()
+ engine = trt_runtime.deserialize_cuda_engine(engine_data)
+ return engine
+
+
+class TensorRTModel:
+ def __init__(
+ self,
+ trt_engine_path,
+ **kwargs,
+ ):
+ cuda.init()
+ stream = cuda.Stream()
+ TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+ trt.init_libnvinfer_plugins(TRT_LOGGER, "")
+ trt_runtime = trt.Runtime(TRT_LOGGER)
+ engine = load_engine(trt_runtime, trt_engine_path)
+ context = engine.create_execution_context()
+
+ # allocates memory for network inputs/outputs on both CPU and GPU
+ host_inputs = []
+ cuda_inputs = []
+ host_outputs = []
+ cuda_outputs = []
+ bindings = []
+ input_names = []
+ output_names = []
+
+ for binding in engine:
+ datatype = engine.get_binding_dtype(binding)
+ if datatype == trt.DataType.HALF:
+ dtype = np.float16
+ else:
+ dtype = np.float32
+
+ shape = tuple(engine.get_binding_shape(binding))
+ host_mem = cuda.pagelocked_empty(shape, dtype)
+ cuda_mem = cuda.mem_alloc(host_mem.nbytes)
+ bindings.append(int(cuda_mem))
+
+ if engine.binding_is_input(binding):
+ host_inputs.append(host_mem)
+ cuda_inputs.append(cuda_mem)
+ input_names.append(binding)
+ else:
+ host_outputs.append(host_mem)
+ cuda_outputs.append(cuda_mem)
+ output_names.append(binding)
+
+ self.stream = stream
+ self.context = context
+ self.engine = engine
+
+ self.host_inputs = host_inputs
+ self.cuda_inputs = cuda_inputs
+ self.host_outputs = host_outputs
+ self.cuda_outputs = cuda_outputs
+ self.bindings = bindings
+ self.batch_size = engine.max_batch_size
+
+ self.input_names = input_names
+ self.output_names = output_names
+
+ def __call__(self, **kwargs):
+ context = self.context
+ stream = self.stream
+ bindings = self.bindings
+
+ host_inputs = self.host_inputs
+ cuda_inputs = self.cuda_inputs
+ host_outputs = self.host_outputs
+ cuda_outputs = self.cuda_outputs
+
+ for idx, input_name in enumerate(self.input_names):
+ _input = kwargs[input_name]
+ np.copyto(host_inputs[idx], _input)
+ # transfer input data to the GPU
+ cuda.memcpy_htod_async(cuda_inputs[idx], host_inputs[idx], stream)
+
+ context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+
+ result = {}
+ for idx, output_name in enumerate(self.output_names):
+ # transfer predictions back from the GPU
+ cuda.memcpy_dtoh_async(host_outputs[idx], cuda_outputs[idx], stream)
+ result[output_name] = host_outputs[idx]
+
+ stream.synchronize()
+
+ return result
+
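+# A minimal usage sketch for the TensorRTModel wrapper above (illustrative only, not executed here):
+# inputs are numpy arrays keyed by the engine's *input* binding names, and the call returns a dict
+# keyed by the *output* binding names. The binding names below mirror the UNet + ControlNet call made
+# later in this file; the engine path and tensor shapes are placeholders and must match whatever the
+# engine was actually built with.
+#
+#     trt_unet = TensorRTModel("unet_fp16.engine")
+#     out = trt_unet(
+#         sample=np.zeros((2, 4, 64, 64), dtype=np.float16),
+#         timestep=np.array([999], dtype=np.float16),
+#         encoder_hidden_states=np.zeros((2, 77, 768), dtype=np.float16),
+#         controlnet_conds=np.zeros((2, 2, 3, 512, 512), dtype=np.float16),
+#         conditioning_scales=np.resize(np.array([0.8, 0.8], dtype=np.float16), (2, 1)),
+#     )
+#     noise_pred = out["noise_pred"]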
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+ ... )
+ >>> np_image = np.array(image)
+
+ >>> # get canny image
+ >>> np_image = cv2.Canny(np_image, 100, 200)
+ >>> np_image = np_image[:, :, None]
+ >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
+ >>> canny_image = Image.fromarray(np_image)
+
+ >>> # load control net and stable diffusion v1-5
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+
+ >>> # speed up diffusion process with faster scheduler and memory optimization
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> generator = torch.manual_seed(0)
+ >>> image = pipe(
+ ... "futuristic-looking woman",
+ ... num_inference_steps=20,
+ ... generator=generator,
+ ... image=image,
+ ... control_image=canny_image,
+ ... ).images[0]
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
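+# For reference (illustrative, not executed here): a single 512x512 RGB PIL image passed through
+# prepare_image() comes out as a float32 torch tensor of shape (1, 3, 512, 512) with values in [-1, 1].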
+
+class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: TensorRTModel
+ scheduler: KarrasDiffusionSchedulers
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: TensorRTModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (4 - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: Optional[int],
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str],
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+
+ prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ num_controlnet,
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Check `image`
+ if num_controlnet == 1:
+ self.check_image(image, prompt, prompt_embeds)
+ elif num_controlnet > 1:
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(image) != num_controlnet:
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {num_controlnet} ControlNets."
+ )
+
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if num_controlnet == 1:
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif num_controlnet > 1:
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif (
+ isinstance(controlnet_conditioning_scale, list)
+ and len(controlnet_conditioning_scale) != num_controlnet
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if num_controlnet > 1:
+ if len(control_guidance_start) != num_controlnet:
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_controlnet} controlnets available. Make sure to provide {num_controlnet}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
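+ # Worked example: with num_inference_steps=20 and strength=0.75 (the values used in __main__),
+ # init_timestep = 15 and t_start = 5, so the img2img run performs the last 15 of the 20 scheduled
+ # steps, starting from a partially noised version of the input image.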
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ _image = image.cpu().detach().numpy()
+ init_latents = self.vae_encoder(sample=_image)[0]
+ init_latents = torch.from_numpy(init_latents).to(device=device, dtype=dtype)
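+ # 0.18215 is the standard latent scaling factor of the Stable Diffusion v1.x VAE
+ # (`vae.config.scaling_factor`); the decoder call later divides by the same constant.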
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ num_controlnet: int,
+ fp16: bool = True,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.Tensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.Tensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ control_image: Union[
+ torch.Tensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.Tensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The initial image to be used as the starting point for the image generation process. Image latents can
+ also be passed as `image`; if latents are passed directly, they are not encoded again.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If
+ the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+ specified in init, images must be passed as a list such that each element of the list can be correctly
+ batched for input to a single controlnet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` in equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original UNet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list. Note that this pipeline uses a smaller default conditioning scale than
+ [`~StableDiffusionControlNetPipeline.__call__`].
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the controlnet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the controlnet stops applying.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if fp16:
+ torch_dtype = torch.float16
+ np_dtype = np.float16
+ else:
+ torch_dtype = torch.float32
+ np_dtype = np.float32
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = num_controlnet
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ num_controlnet,
+ prompt,
+ control_image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if num_controlnet > 1 and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnet
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ # 4. Prepare image
+ image = self.image_processor.preprocess(image).to(dtype=torch.float32)
+
+ # 5. Prepare controlnet_conditioning_image
+ if num_controlnet == 1:
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=torch_dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif num_controlnet > 1:
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=torch_dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ assert False
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 7. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ torch_dtype,
+ device,
+ generator,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8.1 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if num_controlnet == 1 else keeps)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # predict the noise residual
+ _latent_model_input = latent_model_input.cpu().detach().numpy()
+ _prompt_embeds = np.array(prompt_embeds, dtype=np_dtype)
+ _t = np.array([t.cpu().detach().numpy()], dtype=np_dtype)
+
+ if num_controlnet == 1:
+ control_images = np.array([control_image], dtype=np_dtype)
+ else:
+ control_images = []
+ for _control_img in control_image:
+ _control_img = _control_img.cpu().detach().numpy()
+ control_images.append(_control_img)
+ control_images = np.array(control_images, dtype=np_dtype)
+
+ control_scales = np.array(cond_scale, dtype=np_dtype)
+ control_scales = np.resize(control_scales, (num_controlnet, 1))
+
+ noise_pred = self.unet(
+ sample=_latent_model_input,
+ timestep=_t,
+ encoder_hidden_states=_prompt_embeds,
+ controlnet_conds=control_images,
+ conditioning_scales=control_scales,
+ )["noise_pred"]
+ noise_pred = torch.from_numpy(noise_pred).to(device)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ _latents = latents.cpu().detach().numpy() / 0.18215
+ _latents = np.array(_latents, dtype=np_dtype)
+ image = self.vae_decoder(latent_sample=_latents)[0]
+ image = torch.from_numpy(image).to(device, dtype=torch.float32)
+ has_nsfw_concept = None
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--sd_model",
+ type=str,
+ required=True,
+ help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
+ )
+
+ parser.add_argument(
+ "--onnx_model_dir",
+ type=str,
+ required=True,
+ help="Path to the ONNX directory",
+ )
+
+ parser.add_argument(
+ "--unet_engine_path",
+ type=str,
+ required=True,
+ help="Path to the unet + controlnet tensorrt model",
+ )
+
+ parser.add_argument("--qr_img_path", type=str, required=True, help="Path to the qr code image")
+
+ args = parser.parse_args()
+
+ qr_image = Image.open(args.qr_img_path)
+ qr_image = qr_image.resize((512, 512))
+
+ # init stable diffusion pipeline
+ pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(args.sd_model)
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+
+ provider = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ onnx_pipeline = TensorRTStableDiffusionControlNetImg2ImgPipeline(
+ vae_encoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "vae_encoder"), provider=provider
+ ),
+ vae_decoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "vae_decoder"), provider=provider
+ ),
+ text_encoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "text_encoder"), provider=provider
+ ),
+ tokenizer=pipeline.tokenizer,
+ unet=TensorRTModel(args.unet_engine_path),
+ scheduler=pipeline.scheduler,
+ )
+ onnx_pipeline = onnx_pipeline.to("cuda")
+
+ prompt = "a cute cat fly to the moon"
+ negative_prompt = "paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect"
+
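+ # Run 10 timed generations back to back; the first iteration is typically slower than the rest
+ # because of one-time warm-up costs (kernel selection, memory allocation, and similar setup).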
+ for i in range(10):
+ start_time = time.time()
+ image = onnx_pipeline(
+ num_controlnet=2,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=qr_image,
+ control_image=[qr_image, qr_image],
+ width=512,
+ height=512,
+ strength=0.75,
+ num_inference_steps=20,
+ num_images_per_prompt=1,
+ controlnet_conditioning_scale=[0.8, 0.8],
+ control_guidance_start=[0.3, 0.3],
+ control_guidance_end=[0.9, 0.9],
+ ).images[0]
+ print(time.time() - start_time)
+ image.save("output_qr_code.png")
diff --git a/diffusers/examples/community/scheduling_ufogen.py b/diffusers/examples/community/scheduling_ufogen.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b1b92ff183ac16c6c2a2b6a28d9da8d16f192da
--- /dev/null
+++ b/diffusers/examples/community/scheduling_ufogen.py
@@ -0,0 +1,521 @@
+# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.utils import BaseOutput
+from diffusers.utils.torch_utils import randn_tensor
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UFOGen
+class UFOGenSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.Tensor
+ pred_original_sample: Optional[torch.Tensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
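+# A minimal usage sketch (illustrative, not executed here):
+#
+#     betas = betas_for_alpha_bar(1000)                   # cosine schedule, shape (1000,)
+#     alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
+#
+# Because beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i) (clamped at max_beta), the cumulative
+# product of (1 - beta) telescopes to roughly alpha_bar(t) / alpha_bar(t_0): it starts near 1 and
+# decays toward 0 at the end of the schedule.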
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+
+ Args:
+ betas (`torch.Tensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.Tensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
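+# A minimal sketch of the effect (illustrative, not executed here): after rescaling, the final
+# cumulative alpha product is (numerically) zero, i.e. the last timestep corresponds to pure noise:
+#
+#     betas = torch.linspace(0.0001, 0.02, 1000)
+#     rescaled = rescale_zero_terminal_snr(betas)
+#     assert torch.cumprod(1.0 - rescaled, dim=0)[-1].abs() < 1e-6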
+
+class UFOGenScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `UFOGenScheduler` implements multistep and onestep sampling for a UFOGen model, introduced in
+ [UFOGen: You Forward Once Large Scale Text-to-Image Generation via Diffusion GANs](https://arxiv.org/abs/2311.09257)
+ by Yanwu Xu, Yang Zhao, Zhisheng Xiao, and Tingbo Hou. UFOGen is a variant of the denoising diffusion GAN (DDGAN)
+ model designed for one-step sampling.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ denoising_step_size (`int`, defaults to 250):
+ The denoising step size parameter from the UFOGen paper. The number of steps used for training is roughly
+ `math.ceil(num_train_timesteps / denoising_step_size)`.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ steps_offset: int = 0,
+ rescale_betas_zero_snr: bool = False,
+ denoising_step_size: int = 250,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "sigmoid":
+ # GeoDiff sigmoid schedule
+ betas = torch.linspace(-6, 6, num_train_timesteps)
+ self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.custom_timesteps = False
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Union[str, torch.device] = None,
+ timesteps: Optional[List[int]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
+ `num_inference_steps` must be `None`.
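+
+        As an illustrative example (not exhaustive): with `num_train_timesteps=1000` and `num_inference_steps=4`,
+        the resulting schedule is `[750, 500, 250, 0]` for `timestep_spacing="leading"` (plus `steps_offset`),
+        `[999, 749, 499, 249]` for `"trailing"`, and `[999, 666, 333, 0]` for `"linspace"`.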
+
+ """
+ if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
+
+ if timesteps is not None:
+ for i in range(1, len(timesteps)):
+ if timesteps[i] >= timesteps[i - 1]:
+                    raise ValueError("`timesteps` must be in descending order.")
+
+ if timesteps[0] >= self.config.num_train_timesteps:
+ raise ValueError(
+                    f"`timesteps` must start before `self.config.num_train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
+ )
+
+ timesteps = np.array(timesteps, dtype=np.int64)
+ self.custom_timesteps = True
+ else:
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ self.custom_timesteps = False
+
+ # TODO: For now, handle special case when num_inference_steps == 1 separately
+ if num_inference_steps == 1:
+ # Set the timestep schedule to num_train_timesteps - 1 rather than 0
+ # (that is, the one-step timestep schedule is always trailing rather than leading or linspace)
+ timesteps = np.array([self.config.num_train_timesteps - 1], dtype=np.int64)
+ else:
+ # TODO: For now, retain the DDPM timestep spacing logic
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+                    # casting to int to avoid issues when num_inference_steps is a power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+                    # casting to int to avoid issues when num_inference_steps is a power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ sample: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[UFOGenSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+            timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`.
+
+ Returns:
+            [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ # 0. Resolve timesteps
+ t = timestep
+ prev_t = self.previous_timestep(t)
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ # beta_prod_t_prev = 1 - alpha_prod_t_prev
+ # current_alpha_t = alpha_prod_t / alpha_prod_t_prev
+ # current_beta_t = 1 - current_alpha_t
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+ " `v_prediction` for UFOGenScheduler."
+ )
+
+ # 3. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 4. Single-step or multi-step sampling
+ # Noise is not used on the final timestep of the timestep schedule.
+ # This also means that noise is not used for one-step sampling.
+ if t != self.timesteps[-1]:
+ # TODO: is this correct?
+ # Sample prev sample x_{t - 1} ~ q(x_{t - 1} | x_0 = G(x_t, t))
+ device = model_output.device
+ noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
+ sqrt_alpha_prod_t_prev = alpha_prod_t_prev**0.5
+ sqrt_one_minus_alpha_prod_t_prev = (1 - alpha_prod_t_prev) ** 0.5
+ pred_prev_sample = sqrt_alpha_prod_t_prev * pred_original_sample + sqrt_one_minus_alpha_prod_t_prev * noise
+ else:
+ # Simply return the pred_original_sample. If `prediction_type == "sample"`, this is equivalent to returning
+ # the output of the GAN generator U-Net on the initial noisy latents x_T ~ N(0, I).
+ pred_prev_sample = pred_original_sample
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return UFOGenSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+
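+    # A minimal, illustrative sketch of how `set_timesteps` and `step` are typically driven (not part of this
+    # file; `unet` is an assumed trained model returning an object with a `.sample` attribute):
+    #
+    #     scheduler = UFOGenScheduler()
+    #     scheduler.set_timesteps(num_inference_steps=4)
+    #     sample = torch.randn(1, 4, 64, 64)  # x_T ~ N(0, I)
+    #     for t in scheduler.timesteps:
+    #         model_output = unet(sample, t).sample
+    #         sample = scheduler.step(model_output, t, sample).prev_sample
+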
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
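+    # `get_velocity` below computes the v-prediction target
+    #     v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample
+    # (Salimans & Ho, "Progressive Distillation for Fast Sampling of Diffusion Models", 2022).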
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
+ def previous_timestep(self, timestep):
+ if self.custom_timesteps:
+ index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+ if index == self.timesteps.shape[0] - 1:
+ prev_t = torch.tensor(-1)
+ else:
+ prev_t = self.timesteps[index + 1]
+ else:
+ num_inference_steps = (
+ self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+ )
+ prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+ return prev_t
diff --git a/diffusers/examples/community/sd_text2img_k_diffusion.py b/diffusers/examples/community/sd_text2img_k_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f83973aba918dc7a166df128d7c27e6dfc8f6d7
--- /dev/null
+++ b/diffusers/examples/community/sd_text2img_k_diffusion.py
@@ -0,0 +1,414 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import warnings
+from typing import Callable, List, Optional, Union
+
+import torch
+from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
+
+from diffusers import DiffusionPipeline, LMSDiscreteScheduler, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class ModelWrapper:
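+    # Thin adapter that exposes a diffusers UNet through the CompVis-style `apply_model(x, t, cond)` interface
+    # expected by k-diffusion's `CompVisDenoiser` / `CompVisVDenoiser` wrappers; `alphas_cumprod` is read by
+    # those wrappers to build their sigma schedule.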
+ def __init__(self, model, alphas_cumprod):
+ self.model = model
+ self.alphas_cumprod = alphas_cumprod
+
+ def apply_model(self, *args, **kwargs):
+ if len(args) == 3:
+ encoder_hidden_states = args[-1]
+ args = args[:2]
+ if kwargs.get("cond", None) is not None:
+ encoder_hidden_states = kwargs.pop("cond")
+ return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
+
+
+class StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ scheduler,
+ safety_checker,
+ feature_extractor,
+ ):
+ super().__init__()
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ # get correct sigmas from LMS
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ model = ModelWrapper(unet, scheduler.alphas_cumprod)
+ if scheduler.config.prediction_type == "v_prediction":
+ self.k_diffusion_model = CompVisVDenoiser(model)
+ else:
+ self.k_diffusion_model = CompVisDenoiser(model)
+
+ def set_sampler(self, scheduler_type: str):
+ warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
+ return self.set_scheduler(scheduler_type)
+
+ def set_scheduler(self, scheduler_type: str):
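+        # `scheduler_type` is the name of a sampling function exposed by `k_diffusion.sampling`,
+        # e.g. "sample_euler" or "sample_heun".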
+ library = importlib.import_module("k_diffusion")
+ sampling = getattr(library, "sampling")
+ self.sampler = getattr(sampling, scheduler_type)
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
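+        # 0.18215 is the latent scaling factor of the Stable Diffusion v1 VAE; undo it before decoding.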
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // 8, width // 8)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+        # scaling by the scheduler's initial sigma is done by the caller (see `__call__`, where `latents` is multiplied by `sigmas[0]`)
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
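+
+        Examples:
+
+        A minimal usage sketch (the model id and sampler name below are illustrative assumptions, not
+        prescribed by this file):
+
+        ```py
+        >>> from diffusers import DiffusionPipeline
+
+        >>> pipe = DiffusionPipeline.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", custom_pipeline="sd_text2img_k_diffusion"
+        ... )
+        >>> pipe = pipe.to("cuda")
+
+        >>> # Any sampler exposed by `k_diffusion.sampling` can be selected by name.
+        >>> pipe.set_scheduler("sample_heun")
+        >>> image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
+        ```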
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = True
+ if guidance_scale <= 1.0:
+            raise ValueError("This pipeline always applies classifier-free guidance, so `guidance_scale` must be greater than 1.0.")
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
+ sigmas = self.scheduler.sigmas
+ sigmas = sigmas.to(text_embeddings.dtype)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+ latents = latents * sigmas[0]
+ self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
+ self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
+
+ def model_fn(x, t):
+ latent_model_input = torch.cat([x] * 2)
+
+ noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)
+
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ return noise_pred
+
+ latents = self.sampler(model_fn, latents, sigmas)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/sde_drag.py b/diffusers/examples/community/sde_drag.py
new file mode 100644
index 0000000000000000000000000000000000000000..902eaa99f417431b3ffe4b9c96da9b2953828d93
--- /dev/null
+++ b/diffusers/examples/community/sde_drag.py
@@ -0,0 +1,594 @@
+import math
+import tempfile
+from typing import List, Optional
+
+import numpy as np
+import PIL.Image
+import torch
+from accelerate import Accelerator
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel
+from diffusers.loaders import AttnProcsLayers, StableDiffusionLoraLoaderMixin
+from diffusers.models.attention_processor import (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ LoRAAttnAddedKVProcessor,
+ LoRAAttnProcessor,
+ LoRAAttnProcessor2_0,
+ SlicedAttnAddedKVProcessor,
+)
+from diffusers.optimization import get_scheduler
+
+
+class SdeDragPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image drag-and-drop editing using stochastic differential equations: https://arxiv.org/abs/2311.01410.
+ Please refer to the [official repository](https://github.com/ML-GSAI/SDE-Drag) for more information.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Please use
+ [`DDIMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: str,
+ image: PIL.Image.Image,
+ mask_image: PIL.Image.Image,
+ source_points: List[List[int]],
+ target_points: List[List[int]],
+ t0: Optional[float] = 0.6,
+ steps: Optional[int] = 200,
+ step_size: Optional[int] = 2,
+ image_scale: Optional[float] = 0.3,
+ adapt_radius: Optional[int] = 5,
+ min_lora_scale: Optional[float] = 0.5,
+ generator: Optional[torch.Generator] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for image editing.
+ Args:
+ prompt (`str`, *required*):
+ The prompt to guide the image editing.
+ image (`PIL.Image.Image`, *required*):
+                The image to be edited. Parts of the image will be masked out with `mask_image` and edited
+                according to `prompt`.
+            mask_image (`PIL.Image.Image`, *required*):
+                The mask to apply to `image`. White pixels in the mask will be edited, while black pixels will be
+                preserved.
+ source_points (`List[List[int]]`, *required*):
+ Used to mark the starting positions of drag editing in the image, with each pixel represented as a
+ `List[int]` of length 2.
+ target_points (`List[List[int]]`, *required*):
+ Used to mark the target positions of drag editing in the image, with each pixel represented as a
+ `List[int]` of length 2.
+ t0 (`float`, *optional*, defaults to 0.6):
+ The time parameter. Higher t0 improves the fidelity while lowering the faithfulness of the edited images
+ and vice versa.
+ steps (`int`, *optional*, defaults to 200):
+ The number of sampling iterations.
+ step_size (`int`, *optional*, defaults to 2):
+                The drag distance of each drag step.
+ image_scale (`float`, *optional*, defaults to 0.3):
+                To avoid duplicating the content, `image_scale` is used to perturb the source region.
+ adapt_radius (`int`, *optional*, defaults to 5):
+ The size of the region for copy and paste operations during each step of the drag process.
+ min_lora_scale (`float`, *optional*, defaults to 0.5):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ min_lora_scale specifies the minimum LoRA scale during the image drag-editing process.
+            generator (`torch.Generator`, *optional*, defaults to `None`):
+                A torch generator to make generation deterministic
+                (https://pytorch.org/docs/stable/generated/torch.Generator.html).
+ Examples:
+ ```py
+ >>> import PIL
+ >>> import torch
+ >>> from diffusers import DDIMScheduler, DiffusionPipeline
+
+ >>> # Load the pipeline
+ >>> model_path = "runwayml/stable-diffusion-v1-5"
+ >>> scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
+ >>> pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
+ >>> pipe.to('cuda')
+
+ >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+ >>> # If not training LoRA, please avoid using torch.float16
+ >>> # pipe.to(torch.float16)
+
+ >>> # Provide prompt, image, mask image, and the starting and target points for drag editing.
+ >>> prompt = "prompt of the image"
+ >>> image = PIL.Image.open('/path/to/image')
+ >>> mask_image = PIL.Image.open('/path/to/mask_image')
+ >>> source_points = [[123, 456]]
+ >>> target_points = [[234, 567]]
+
+ >>> # train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.
+ >>> pipe.train_lora(prompt, image)
+
+ >>> output = pipe(prompt, image, mask_image, source_points, target_points)
+ >>> output_image = PIL.Image.fromarray(output)
+ >>> output_image.save("./output.png")
+ ```
+ """
+
+ self.scheduler.set_timesteps(steps)
+
+ noise_scale = (1 - image_scale**2) ** (0.5)
+
+ text_embeddings = self._get_text_embed(prompt)
+ uncond_embeddings = self._get_text_embed([""])
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ latent = self._get_img_latent(image)
+
+ mask = mask_image.resize((latent.shape[3], latent.shape[2]))
+ mask = torch.tensor(np.array(mask))
+ mask = mask.unsqueeze(0).expand_as(latent).to(self.device)
+
+ source_points = torch.tensor(source_points).div(torch.tensor([8]), rounding_mode="trunc")
+ target_points = torch.tensor(target_points).div(torch.tensor([8]), rounding_mode="trunc")
+
+ distance = target_points - source_points
+ distance_norm_max = torch.norm(distance.float(), dim=1, keepdim=True).max()
+
+ if distance_norm_max <= step_size:
+ drag_num = 1
+ else:
+ drag_num = distance_norm_max.div(torch.tensor([step_size]), rounding_mode="trunc")
+ if (distance_norm_max / drag_num - step_size).abs() > (
+ distance_norm_max / (drag_num + 1) - step_size
+ ).abs():
+ drag_num += 1
+
+ latents = []
+ for i in tqdm(range(int(drag_num)), desc="SDE Drag"):
+ source_new = source_points + (i / drag_num * distance).to(torch.int)
+ target_new = source_points + ((i + 1) / drag_num * distance).to(torch.int)
+
+ latent, noises, hook_latents, lora_scales, cfg_scales = self._forward(
+ latent, steps, t0, min_lora_scale, text_embeddings, generator
+ )
+ latent = self._copy_and_paste(
+ latent,
+ source_new,
+ target_new,
+ adapt_radius,
+ latent.shape[2] - 1,
+ latent.shape[3] - 1,
+ image_scale,
+ noise_scale,
+ generator,
+ )
+ latent = self._backward(
+ latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
+ )
+
+ latents.append(latent)
+
+ result_image = 1 / 0.18215 * latents[-1]
+
+ with torch.no_grad():
+ result_image = self.vae.decode(result_image).sample
+
+ result_image = (result_image / 2 + 0.5).clamp(0, 1)
+ result_image = result_image.cpu().permute(0, 2, 3, 1).numpy()[0]
+ result_image = (result_image * 255).astype(np.uint8)
+
+ return result_image
+
+ def train_lora(self, prompt, image, lora_step=100, lora_rank=16, generator=None):
+ accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision="fp16")
+
+ self.vae.requires_grad_(False)
+ self.text_encoder.requires_grad_(False)
+ self.unet.requires_grad_(False)
+
+ unet_lora_attn_procs = {}
+ for name, attn_processor in self.unet.attn_processors.items():
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = self.unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = self.unet.config.block_out_channels[block_id]
+ else:
+                raise NotImplementedError("name must start with up_blocks, mid_block, or down_blocks")
+
+ if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
+ lora_attn_processor_class = LoRAAttnAddedKVProcessor
+ else:
+ lora_attn_processor_class = (
+ LoRAAttnProcessor2_0
+ if hasattr(torch.nn.functional, "scaled_dot_product_attention")
+ else LoRAAttnProcessor
+ )
+ unet_lora_attn_procs[name] = lora_attn_processor_class(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
+ )
+
+ self.unet.set_attn_processor(unet_lora_attn_procs)
+ unet_lora_layers = AttnProcsLayers(self.unet.attn_processors)
+ params_to_optimize = unet_lora_layers.parameters()
+
+ optimizer = torch.optim.AdamW(
+ params_to_optimize,
+ lr=2e-4,
+ betas=(0.9, 0.999),
+ weight_decay=1e-2,
+ eps=1e-08,
+ )
+
+ lr_scheduler = get_scheduler(
+ "constant",
+ optimizer=optimizer,
+ num_warmup_steps=0,
+ num_training_steps=lora_step,
+ num_cycles=1,
+ power=1.0,
+ )
+
+ unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
+ optimizer = accelerator.prepare_optimizer(optimizer)
+ lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
+
+ with torch.no_grad():
+ text_inputs = self._tokenize_prompt(prompt, tokenizer_max_length=None)
+ text_embedding = self._encode_prompt(
+ text_inputs.input_ids, text_inputs.attention_mask, text_encoder_use_attention_mask=False
+ )
+
+ image_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ image = image_transforms(image).to(self.device, dtype=self.vae.dtype)
+ image = image.unsqueeze(dim=0)
+ latents_dist = self.vae.encode(image).latent_dist
+
+ for _ in tqdm(range(lora_step), desc="Train LoRA"):
+ self.unet.train()
+ model_input = latents_dist.sample() * self.vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(
+ model_input.size(),
+ dtype=model_input.dtype,
+ layout=model_input.layout,
+ device=model_input.device,
+ generator=generator,
+ )
+ bsz, channels, height, width = model_input.shape
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, self.scheduler.config.num_train_timesteps, (bsz,), device=model_input.device, generator=generator
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = self.scheduler.add_noise(model_input, noise, timesteps)
+
+ # Predict the noise residual
+ model_pred = self.unet(noisy_model_input, timesteps, text_embedding).sample
+
+ # Get the target for loss depending on the prediction type
+ if self.scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif self.scheduler.config.prediction_type == "v_prediction":
+ target = self.scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")
+
+ loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ accelerator.backward(loss)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ with tempfile.TemporaryDirectory() as save_lora_dir:
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
+ save_directory=save_lora_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=None,
+ )
+
+ self.unet.load_attn_procs(save_lora_dir)
+
+ def _tokenize_prompt(self, prompt, tokenizer_max_length=None):
+ if tokenizer_max_length is not None:
+ max_length = tokenizer_max_length
+ else:
+ max_length = self.tokenizer.model_max_length
+
+ text_inputs = self.tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+
+ return text_inputs
+
+ def _encode_prompt(self, input_ids, attention_mask, text_encoder_use_attention_mask=False):
+ text_input_ids = input_ids.to(self.device)
+
+ if text_encoder_use_attention_mask:
+ attention_mask = attention_mask.to(self.device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
+
+ @torch.no_grad()
+ def _get_text_embed(self, prompt):
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ return text_embeddings
+
+ def _copy_and_paste(
+ self, latent, source_new, target_new, adapt_radius, max_height, max_width, image_scale, noise_scale, generator
+ ):
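+        # For every drag point: replace the source patch with a down-weighted copy plus fresh noise (so the
+        # content is not duplicated verbatim) and paste a slightly amplified copy of it onto the target patch.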
+ def adaption_r(source, target, adapt_radius, max_height, max_width):
+ r_x_lower = min(adapt_radius, source[0], target[0])
+ r_x_upper = min(adapt_radius, max_width - source[0], max_width - target[0])
+ r_y_lower = min(adapt_radius, source[1], target[1])
+ r_y_upper = min(adapt_radius, max_height - source[1], max_height - target[1])
+ return r_x_lower, r_x_upper, r_y_lower, r_y_upper
+
+ for source_, target_ in zip(source_new, target_new):
+ r_x_lower, r_x_upper, r_y_lower, r_y_upper = adaption_r(
+ source_, target_, adapt_radius, max_height, max_width
+ )
+
+ source_feature = latent[
+ :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
+ ].clone()
+
+ latent[
+ :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
+ ] = image_scale * source_feature + noise_scale * torch.randn(
+ latent.shape[0],
+ 4,
+ r_y_lower + r_y_upper,
+ r_x_lower + r_x_upper,
+ device=self.device,
+ generator=generator,
+ )
+
+ latent[
+ :, :, target_[1] - r_y_lower : target_[1] + r_y_upper, target_[0] - r_x_lower : target_[0] + r_x_upper
+ ] = source_feature * 1.1
+ return latent
+
+ @torch.no_grad()
+    def _get_img_latent(self, image, height=None, width=None):
+        data = image.convert("RGB")
+        if height is not None:
+            data = data.resize((width, height))
+ transform = transforms.ToTensor()
+ data = transform(data).unsqueeze(0)
+ data = (data * 2.0) - 1.0
+ data = data.to(self.device, dtype=self.vae.dtype)
+ latent = self.vae.encode(data).latent_dist.sample()
+ latent = 0.18215 * latent
+ return latent
+
+ @torch.no_grad()
+ def _get_eps(self, latent, timestep, guidance_scale, text_embeddings, lora_scale=None):
+ latent_model_input = torch.cat([latent] * 2) if guidance_scale > 1.0 else latent
+ text_embeddings = text_embeddings if guidance_scale > 1.0 else text_embeddings.chunk(2)[1]
+
+ cross_attention_kwargs = None if lora_scale is None else {"scale": lora_scale}
+
+ with torch.no_grad():
+ noise_pred = self.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=text_embeddings,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ if guidance_scale > 1.0:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ elif guidance_scale == 1.0:
+ noise_pred_text = noise_pred
+ noise_pred_uncond = 0.0
+ else:
+ raise NotImplementedError(guidance_scale)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ return noise_pred
+
+ def _forward_sde(
+ self, timestep, sample, guidance_scale, text_embeddings, steps, eta=1.0, lora_scale=None, generator=None
+ ):
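+        # One step in the noising direction: jump from `timestep` to the larger `prev_timestep`, then recover
+        # the noise term that maps the noisier sample back to `sample`, so `_sample` can replay it later.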
+ num_train_timesteps = len(self.scheduler)
+ alphas_cumprod = self.scheduler.alphas_cumprod
+ initial_alpha_cumprod = torch.tensor(1.0)
+
+ prev_timestep = timestep + num_train_timesteps // steps
+
+ alpha_prod_t = alphas_cumprod[timestep] if timestep >= 0 else initial_alpha_cumprod
+ alpha_prod_t_prev = alphas_cumprod[prev_timestep]
+
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ x_prev = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) * sample + (1 - alpha_prod_t_prev / alpha_prod_t) ** (
+ 0.5
+ ) * torch.randn(
+ sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
+ )
+ eps = self._get_eps(x_prev, prev_timestep, guidance_scale, text_embeddings, lora_scale)
+
+ sigma_t_prev = (
+ eta
+ * (1 - alpha_prod_t) ** (0.5)
+ * (1 - alpha_prod_t_prev / (1 - alpha_prod_t_prev) * (1 - alpha_prod_t) / alpha_prod_t) ** (0.5)
+ )
+
+ pred_original_sample = (x_prev - beta_prod_t_prev ** (0.5) * eps) / alpha_prod_t_prev ** (0.5)
+ pred_sample_direction_coeff = (1 - alpha_prod_t - sigma_t_prev**2) ** (0.5)
+
+ noise = (
+ sample - alpha_prod_t ** (0.5) * pred_original_sample - pred_sample_direction_coeff * eps
+ ) / sigma_t_prev
+
+ return x_prev, noise
+
+ def _sample(
+ self,
+ timestep,
+ sample,
+ guidance_scale,
+ text_embeddings,
+ steps,
+ sde=False,
+ noise=None,
+ eta=1.0,
+ lora_scale=None,
+ generator=None,
+ ):
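+        # One reverse (denoising) DDIM-style step from `timestep` to the smaller `prev_timestep`; with
+        # `sde=True` the sigma_t term re-injects the provided (or freshly drawn) `noise`.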
+ num_train_timesteps = len(self.scheduler)
+ alphas_cumprod = self.scheduler.alphas_cumprod
+ final_alpha_cumprod = torch.tensor(1.0)
+
+ eps = self._get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale)
+
+ prev_timestep = timestep - num_train_timesteps // steps
+
+ alpha_prod_t = alphas_cumprod[timestep]
+ alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ sigma_t = (
+ eta
+ * ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** (0.5)
+ * (1 - alpha_prod_t / alpha_prod_t_prev) ** (0.5)
+ if sde
+ else 0
+ )
+
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5)
+ pred_sample_direction_coeff = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5)
+
+ noise = (
+ torch.randn(
+ sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
+ )
+ if noise is None
+ else noise
+ )
+ latent = (
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction_coeff * eps + sigma_t * noise
+ )
+
+ return latent
+
+ def _forward(self, latent, steps, t0, lora_scale_min, text_embeddings, generator):
+ def scale_schedule(begin, end, n, length, type="linear"):
+ if type == "constant":
+ return end
+ elif type == "linear":
+ return begin + (end - begin) * n / length
+ elif type == "cos":
+ factor = (1 - math.cos(n * math.pi / length)) / 2
+ return (1 - factor) * begin + factor * end
+ else:
+ raise NotImplementedError(type)
+
+ noises = []
+ latents = []
+ lora_scales = []
+ cfg_scales = []
+ latents.append(latent)
+ t0 = int(t0 * steps)
+ t_begin = steps - t0
+
+ length = len(self.scheduler.timesteps[t_begin - 1 : -1]) - 1
+ index = 1
+ for t in self.scheduler.timesteps[t_begin:].flip(dims=[0]):
+ lora_scale = scale_schedule(1, lora_scale_min, index, length, type="cos")
+ cfg_scale = scale_schedule(1, 3.0, index, length, type="linear")
+ latent, noise = self._forward_sde(
+ t, latent, cfg_scale, text_embeddings, steps, lora_scale=lora_scale, generator=generator
+ )
+
+ noises.append(noise)
+ latents.append(latent)
+ lora_scales.append(lora_scale)
+ cfg_scales.append(cfg_scale)
+ index += 1
+ return latent, noises, latents, lora_scales, cfg_scales
+
+ def _backward(
+ self, latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
+ ):
+ t0 = int(t0 * steps)
+ t_begin = steps - t0
+
+ hook_latent = hook_latents.pop()
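+        # `mask` holds 0-255 pixel values: keep the edited latent where the mask is white (> 128) and restore
+        # the recorded forward-pass latent elsewhere.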
+ latent = torch.where(mask > 128, latent, hook_latent)
+ for t in self.scheduler.timesteps[t_begin - 1 : -1]:
+ latent = self._sample(
+ t,
+ latent,
+ cfg_scales.pop(),
+ text_embeddings,
+ steps,
+ sde=True,
+ noise=noises.pop(),
+ lora_scale=lora_scales.pop(),
+ generator=generator,
+ )
+ hook_latent = hook_latents.pop()
+ latent = torch.where(mask > 128, latent, hook_latent)
+ return latent
diff --git a/diffusers/examples/community/seed_resize_stable_diffusion.py b/diffusers/examples/community/seed_resize_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae2d8a53b2984ad74606cae2a4293e578f56b939
--- /dev/null
+++ b/diffusers/examples/community/seed_resize_stable_diffusion.py
@@ -0,0 +1,342 @@
+"""
+Modified from the Hugging Face diffusers library: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+"""
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class SeedResizeStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ text_embeddings: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
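+ text_embeddings (`torch.Tensor`, *optional*):
+ Pre-computed text embeddings. If provided, they are used in place of encoding `prompt` with the text
+ encoder; `prompt` is still tokenized to determine the padding length for the unconditional embeddings.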
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ if text_embeddings is None:
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""]
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.config.in_channels, 64, 64)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents_reference = torch.randn(
+ latents_shape_reference, generator=generator, device="cpu", dtype=latents_dtype
+ ).to(self.device)
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents_reference = torch.randn(
+ latents_shape_reference, generator=generator, device=self.device, dtype=latents_dtype
+ )
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ # no reference noise is created in this branch, so the reference blending below is skipped
+ latents_reference = None
+ latents = latents.to(self.device)
+
+ # This is the key part of the pipeline where we
+ # try to ensure that the generated images w/ the same seed
+ # but different sizes actually result in similar images
+ dx = (latents_shape[3] - latents_shape_reference[3]) // 2
+ dy = (latents_shape[2] - latents_shape_reference[2]) // 2
+ w = latents_shape_reference[3] if dx >= 0 else latents_shape_reference[3] + 2 * dx
+ h = latents_shape_reference[2] if dy >= 0 else latents_shape_reference[2] + 2 * dy
+ tx = 0 if dx < 0 else dx
+ ty = 0 if dy < 0 else dy
+ dx = max(-dx, 0)
+ dy = max(-dy, 0)
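+ # worked example (sizes assumed for illustration): a 512x768 request gives a 64x96 latent grid versus the
+ # 64x64 reference grid, so dx = (96 - 64) // 2 = 16, dy = 0, w = h = 64, tx = 16, ty = 0, and the reference
+ # noise is pasted into the horizontal centre of the larger latent tensor below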
+ if latents_reference is not None:
+ latents[:, :, ty : ty + h, tx : tx + w] = latents_reference[:, :, dy : dy + h, dx : dx + w]
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
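+ # e.g. with guidance_scale = 7.5: noise_pred = eps_uncond + 7.5 * (eps_text - eps_uncond);
+ # guidance_scale = 1.0 reduces this to eps_text, i.e. plain conditional sampling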
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
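+ # 0.18215 is the VAE scaling factor used by Stable Diffusion; undo it before decoding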
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/speech_to_image_diffusion.py b/diffusers/examples/community/speech_to_image_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cb5a2a8c7f96858850c47b03549e579274750d6
--- /dev/null
+++ b/diffusers/examples/community/speech_to_image_diffusion.py
@@ -0,0 +1,255 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ WhisperForConditionalGeneration,
+ WhisperProcessor,
+)
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class SpeechToImagePipeline(DiffusionPipeline, StableDiffusionMixin):
+ def __init__(
+ self,
+ speech_model: WhisperForConditionalGeneration,
+ speech_processor: WhisperProcessor,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ speech_model=speech_model,
+ speech_processor=speech_processor,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ audio,
+ sampling_rate=16_000,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ inputs = self.speech_processor.feature_extractor(
+ audio, return_tensors="pt", sampling_rate=sampling_rate
+ ).input_features.to(self.device)
+ predicted_ids = self.speech_model.generate(inputs, max_length=480_000)
+
+ prompt = self.speech_processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[
+ 0
+ ]
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return image
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
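+
+# Hypothetical usage sketch (not part of the file above). The Whisper checkpoint and the audio array are
+# placeholders, assuming the diffusers `custom_pipeline` loader resolves this module by its file stem:
+#
+# from diffusers import DiffusionPipeline
+# from transformers import WhisperForConditionalGeneration, WhisperProcessor
+#
+# speech_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
+# speech_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+# pipe = DiffusionPipeline.from_pretrained(
+#     "runwayml/stable-diffusion-v1-5",
+#     custom_pipeline="speech_to_image_diffusion",
+#     speech_model=speech_model,
+#     speech_processor=speech_processor,
+# )
+# image = pipe(audio=audio_waveform, sampling_rate=16_000).images[0]  # audio_waveform: 1-D float array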
diff --git a/diffusers/examples/community/stable_diffusion_comparison.py b/diffusers/examples/community/stable_diffusion_comparison.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b510a64f854ab37c8ab47bc49f37fe4d48c013f
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_comparison.py
@@ -0,0 +1,381 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+
+pipe1_model_id = "CompVis/stable-diffusion-v1-1"
+pipe2_model_id = "CompVis/stable-diffusion-v1-2"
+pipe3_model_id = "CompVis/stable-diffusion-v1-3"
+pipe4_model_id = "CompVis/stable-diffusion-v1-4"
+
+
+class StableDiffusionComparisonPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for parallel comparison of Stable Diffusion v1-v4
+ This pipeline inherits from DiffusionPipeline and depends on the use of an Auth Token for
+ downloading pre-trained checkpoints from Hugging Face Hub.
+ If using the Hugging Face Hub, pass the model ID for Stable Diffusion v1.4; the three earlier checkpoints are
+ loaded automatically.
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ self.pipe1 = StableDiffusionPipeline.from_pretrained(pipe1_model_id)
+ self.pipe2 = StableDiffusionPipeline.from_pretrained(pipe2_model_id)
+ self.pipe3 = StableDiffusionPipeline.from_pretrained(pipe3_model_id)
+ self.pipe4 = StableDiffusionPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=requires_safety_checker,
+ )
+
+ self.register_modules(pipeline1=self.pipe1, pipeline2=self.pipe2, pipeline3=self.pipe3, pipeline4=self.pipe4)
+
+ @property
+ def layers(self) -> Dict[str, Any]:
+ return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
+
+ @torch.no_grad()
+ def text2img_sd1_1(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ return self.pipe1(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def text2img_sd1_2(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ return self.pipe2(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def text2img_sd1_3(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ return self.pipe3(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def text2img_sd1_4(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ return self.pipe4(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation. This function generates four results by running
+ the four pipelines for Stable Diffusion v1.1-v1.4 sequentially on the same inputs.
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, optional, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, optional, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, optional, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, optional, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, optional, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, optional):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, optional):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, optional, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, optional, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # Checks if the height and width are divisible by 8 or not
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` must be divisible by 8 but are {height} and {width}.")
+
+ # Get first result from Stable Diffusion Checkpoint v1.1
+ res1 = self.text2img_sd1_1(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ # Get first result from Stable Diffusion Checkpoint v1.2
+ res2 = self.text2img_sd1_2(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ # Get first result from Stable Diffusion Checkpoint v1.3
+ res3 = self.text2img_sd1_3(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ # Get first result from Stable Diffusion Checkpoint v1.4
+ res4 = self.text2img_sd1_4(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ # Get all result images into a single list and pass it via StableDiffusionPipelineOutput for final result
+ return StableDiffusionPipelineOutput([res1[0], res2[0], res3[0], res4[0]])
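+
+# Hypothetical usage sketch (not part of the file above), assuming the diffusers `custom_pipeline` loader
+# resolves this module as "stable_diffusion_comparison":
+#
+# from diffusers import DiffusionPipeline
+#
+# pipe = DiffusionPipeline.from_pretrained(
+#     "CompVis/stable-diffusion-v1-4", custom_pipeline="stable_diffusion_comparison"
+# )
+# output = pipe(prompt="a photograph of an astronaut riding a horse", num_inference_steps=25)
+# # output.images holds the results from v1.1, v1.2, v1.3 and v1.4, in that order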
diff --git a/diffusers/examples/community/stable_diffusion_controlnet_img2img.py b/diffusers/examples/community/stable_diffusion_controlnet_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7c88d6fdcc72a9c39a1f784fa15698ee2d9d1ed
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_controlnet_img2img.py
@@ -0,0 +1,907 @@
+# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel, logging
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import numpy as np
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+
+ >>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ controlnet=controlnet,
+ safety_checker=None,
+ torch_dtype=torch.float16
+ )
+
+ >>> pipe_controlnet.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
+ >>> pipe_controlnet.enable_xformers_memory_efficient_attention()
+ >>> pipe_controlnet.enable_model_cpu_offload()
+
+ # using image with edges for our canny controlnet
+ >>> control_image = load_image(
+ "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png")
+
+
+ >>> result_img = pipe_controlnet(controlnet_conditioning_image=control_image,
+ image=input_image,
+ prompt="an android robot, cyberpank, digitl art masterpiece",
+ num_inference_steps=20).images[0]
+
+ >>> result_img.show()
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
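+ # e.g. a single 512x512 RGB PIL image becomes a (1, 3, 512, 512) float32 tensor scaled to [-1, 1]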
+
+ return image
+
+
+def prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance,
+):
+ if not isinstance(controlnet_conditioning_image, torch.Tensor):
+ if isinstance(controlnet_conditioning_image, PIL.Image.Image):
+ controlnet_conditioning_image = [controlnet_conditioning_image]
+
+ if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
+ controlnet_conditioning_image = [
+ np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
+ for i in controlnet_conditioning_image
+ ]
+ controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
+ controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
+ controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
+ controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
+ elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
+ controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
+
+ image_batch_size = controlnet_conditioning_image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
+
+ controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance:
+ controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
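+ # at this point the conditioning image sits on `device` in the target dtype (scaled to [0, 1] for PIL
+ # inputs), repeated to the effective batch size and doubled under classifier-free guidance so that it
+ # lines up with the concatenated prompt embeddings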
+
+ return controlnet_conditioning_image
+
+
+class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin):
+ """
+ Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+ raise TypeError(
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ elif image_is_tensor:
+ image_batch_size = image.shape[0]
+ elif image_is_pil_list:
+ image_batch_size = len(image)
+ elif image_is_tensor_list:
+ image_batch_size = len(image)
+ else:
+ raise ValueError("controlnet condition image is not valid")
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+ else:
+ raise ValueError("prompt or prompt_embeds are not valid")
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ strength=None,
+ controlnet_guidance_start=None,
+ controlnet_guidance_end=None,
+ controlnet_conditioning_scale=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # check controlnet condition image
+
+ if isinstance(self.controlnet, ControlNetModel):
+ self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if not isinstance(controlnet_conditioning_image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ if len(controlnet_conditioning_image) != len(self.controlnet.nets):
+ raise ValueError(
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
+ )
+
+ for image_ in controlnet_conditioning_image:
+ self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+
+ if isinstance(self.controlnet, ControlNetModel):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if isinstance(image, torch.Tensor):
+ if image.ndim != 3 and image.ndim != 4:
+ raise ValueError("`image` must have 3 or 4 dimensions")
+
+ if image.ndim == 3:
+ image_batch_size = 1
+ image_channels, image_height, image_width = image.shape
+ elif image.ndim == 4:
+ image_batch_size, image_channels, image_height, image_width = image.shape
+ else:
+ assert False
+
+ if image_channels != 3:
+ raise ValueError("`image` must have 3 channels")
+
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("`image` should be in range [-1, 1]")
+
+ if self.vae.config.latent_channels != self.unet.config.in_channels:
+ raise ValueError(
+ f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
+ f" latent channels: {self.vae.config.latent_channels},"
+ f" Please verify the config of `pipeline.unet` and the `pipeline.vae`"
+ )
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of `strength` should in [0.0, 1.0] but is {strength}")
+
+ if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:
+ raise ValueError(
+ f"The value of `controlnet_guidance_start` should in [0.0, 1.0] but is {controlnet_guidance_start}"
+ )
+
+ if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:
+ raise ValueError(
+ f"The value of `controlnet_guidance_end` should in [0.0, 1.0] but is {controlnet_guidance_end}"
+ )
+
+ if controlnet_guidance_start > controlnet_guidance_end:
+ raise ValueError(
+ "The value of `controlnet_guidance_start` should be less than `controlnet_guidance_end`, but got"
+ f" `controlnet_guidance_start` {controlnet_guidance_start} >= `controlnet_guidance_end` {controlnet_guidance_end}"
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
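+ # worked example: num_inference_steps=50 and strength=0.8 give init_timestep=40 and t_start=10, so only
+ # the last 40 scheduler timesteps are used and 40 is returned as the effective number of steps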
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # duplicate the encoded image latents to match the requested batch size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
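+ # img2img initialisation: the image's VAE latents are noised only up to the starting timestep chosen via
+ # `strength`, so lower strength values preserve more of the original image content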
+ latents = init_latents
+
+ return latents
+
+ def _default_height_width(self, height, width, image):
+ if isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2] # for NCHW tensors, height is dim 2
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
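+ # e.g. a 1023-pixel-tall input is rounded down to 1016 to satisfy the factor-of-8 constraint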
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+                width = image.shape[3]  # (B, C, H, W) layout: dim 3 is width
+
+ width = (width // 8) * 8 # round down to nearest multiple of 8
+
+ return height, width
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ controlnet_conditioning_image: Union[
+ torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]
+ ] = None,
+ strength: float = 0.8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ controlnet_guidance_start: float = 0.0,
+ controlnet_guidance_end: float = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            image (`torch.Tensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to be used as the starting point for the
+                image-to-image generation process.
+ controlnet_conditioning_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]`):
+                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the
+                UNet. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The control image is automatically resized to fit the output image.
+ strength (`float`, *optional*):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet.
+            controlnet_guidance_start (`float`, *optional*, defaults to 0.0):
+                The fraction of total steps at which the ControlNet starts being applied. Must be between 0 and 1.
+            controlnet_guidance_end (`float`, *optional*, defaults to 1.0):
+                The fraction of total steps at which the ControlNet stops being applied. Must be between 0 and 1, and
+                must be greater than `controlnet_guidance_start`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, controlnet_conditioning_image)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ strength,
+ controlnet_guidance_start,
+ controlnet_guidance_end,
+ controlnet_conditioning_scale,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare image, and controlnet_conditioning_image
+ image = prepare_image(image)
+
+ # condition image(s)
+ if isinstance(self.controlnet, ControlNetModel):
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image=controlnet_conditioning_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ controlnet_conditioning_images = []
+
+ for image_ in controlnet_conditioning_image:
+ image_ = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ controlnet_conditioning_images.append(image_)
+
+ controlnet_conditioning_image = controlnet_conditioning_images
+ else:
+ assert False
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ if latents is None:
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # compute the percentage of total steps we are at
+ current_sampling_percent = i / len(timesteps)
+
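+                # Illustration (not exact scheduler values): with 40 timesteps,
+                # controlnet_guidance_start=0.2 and controlnet_guidance_end=0.8, the ControlNet
+                # residuals are only added for steps 8 through 32; outside that window the UNet
+                # runs without the ControlNet residuals.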
+ if (
+ current_sampling_percent < controlnet_guidance_start
+ or current_sampling_percent > controlnet_guidance_end
+ ):
+ # do not apply the controlnet
+ down_block_res_samples = None
+ mid_block_res_sample = None
+ else:
+ # apply the controlnet
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=controlnet_conditioning_image,
+ conditioning_scale=controlnet_conditioning_scale,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_controlnet_inpaint.py b/diffusers/examples/community/stable_diffusion_controlnet_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b473ffe79933d8ed9fb161359e34d78db90441e7
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_controlnet_inpaint.py
@@ -0,0 +1,1060 @@
+# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel, logging
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import numpy as np
+ >>> import torch
+ >>> from PIL import Image
+ >>> from stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
+
+ >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> def ade_palette():
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
+ >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
+
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
+
+ >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ )
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_xformers_memory_efficient_attention()
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> def image_to_seg(image):
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values
+ with torch.no_grad():
+ outputs = image_segmentor(pixel_values)
+ seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
+ palette = np.array(ade_palette())
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ color_seg = color_seg.astype(np.uint8)
+ seg_image = Image.fromarray(color_seg)
+ return seg_image
+
+ >>> image = load_image(
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ )
+
+ >>> mask_image = load_image(
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ )
+
+ >>> controlnet_conditioning_image = image_to_seg(image)
+
+ >>> image = pipe(
+ "Face of a yellow cat, high resolution, sitting on a park bench",
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ num_inference_steps=20,
+ ).images[0]
+
+ >>> image.save("out.png")
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
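+        # e.g. a uint8 pixel value of 0 maps to -1.0 and 255 maps to 1.0, matching the
+        # [-1, 1] range expected by the VAE encoder.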
+
+ return image
+
+
+def prepare_mask_image(mask_image):
+ if isinstance(mask_image, torch.Tensor):
+ if mask_image.ndim == 2:
+ # Batch and add channel dim for single mask
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
+ # Single mask, the 0'th dimension is considered to be
+ # the existing batch size of 1
+ mask_image = mask_image.unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
+ # Batch of mask, the 0'th dimension is considered to be
+ # the batching dimension
+ mask_image = mask_image.unsqueeze(1)
+
+ # Binarize mask
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ else:
+ # preprocess mask
+ if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
+ mask_image = [mask_image]
+
+ if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
+ mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
+ mask_image = mask_image.astype(np.float32) / 255.0
+ elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
+ mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
+
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ mask_image = torch.from_numpy(mask_image)
+
+ return mask_image
+
+
+def prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance,
+):
+ if not isinstance(controlnet_conditioning_image, torch.Tensor):
+ if isinstance(controlnet_conditioning_image, PIL.Image.Image):
+ controlnet_conditioning_image = [controlnet_conditioning_image]
+
+ if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
+ controlnet_conditioning_image = [
+ np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
+ for i in controlnet_conditioning_image
+ ]
+ controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
+ controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
+ controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
+ controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
+ elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
+ controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
+
+ image_batch_size = controlnet_conditioning_image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
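+    # Illustration: a single condition image of shape (1, 3, H, W) with an effective batch size
+    # of 2 becomes (2, 3, H, W) here, and is duplicated again below when classifier-free
+    # guidance is enabled.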
+
+ controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance:
+ controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
+
+ return controlnet_conditioning_image
+
+
+class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, StableDiffusionMixin):
+ """
+ Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
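+        # For the standard Stable Diffusion 1.x VAE, `block_out_channels` has 4 entries, so
+        # `vae_scale_factor` is 2 ** 3 = 8 (a 512x512 image maps to 64x64 latents).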
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
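+            # Illustration (standard SD 1.x CLIP text encoder): the result has shape
+            # (2 * batch_size * num_images_per_prompt, 77, 768). The uncond embeddings come
+            # first, which is why `noise_pred.chunk(2)` in the denoising loop unpacks as
+            # (uncond, text).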
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+ raise TypeError(
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ elif image_is_tensor:
+ image_batch_size = image.shape[0]
+ elif image_is_pil_list:
+ image_batch_size = len(image)
+ elif image_is_tensor_list:
+ image_batch_size = len(image)
+ else:
+ raise ValueError("controlnet condition image is not valid")
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+ else:
+ raise ValueError("prompt or prompt_embeds are not valid")
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # check controlnet condition image
+ if isinstance(self.controlnet, ControlNetModel):
+ self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if not isinstance(controlnet_conditioning_image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+ if len(controlnet_conditioning_image) != len(self.controlnet.nets):
+ raise ValueError(
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
+ )
+ for image_ in controlnet_conditioning_image:
+ self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if isinstance(self.controlnet, ControlNetModel):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
+ raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
+
+ if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
+ raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")
+
+ if isinstance(image, torch.Tensor):
+ if image.ndim != 3 and image.ndim != 4:
+ raise ValueError("`image` must have 3 or 4 dimensions")
+
+ if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
+ raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
+
+ if image.ndim == 3:
+ image_batch_size = 1
+ image_channels, image_height, image_width = image.shape
+ elif image.ndim == 4:
+ image_batch_size, image_channels, image_height, image_width = image.shape
+ else:
+ assert False
+
+ if mask_image.ndim == 2:
+ mask_image_batch_size = 1
+ mask_image_channels = 1
+ mask_image_height, mask_image_width = mask_image.shape
+ elif mask_image.ndim == 3:
+ mask_image_channels = 1
+ mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
+ elif mask_image.ndim == 4:
+ mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape
+
+ if image_channels != 3:
+ raise ValueError("`image` must have 3 channels")
+
+ if mask_image_channels != 1:
+ raise ValueError("`mask_image` must have 1 channel")
+
+ if image_batch_size != mask_image_batch_size:
+                raise ValueError("`image` and `mask_image` must have the same batch sizes")
+
+ if image_height != mask_image_height or image_width != mask_image_width:
+ raise ValueError("`image` and `mask_image` must have the same height and width dimensions")
+
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("`image` should be in range [-1, 1]")
+
+ if mask_image.min() < 0 or mask_image.max() > 1:
+ raise ValueError("`mask_image` should be in range [0, 1]")
+ else:
+ mask_image_channels = 1
+ image_channels = 3
+
+ single_image_latent_channels = self.vae.config.latent_channels
+
+ total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
+
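+        # e.g. for the standard 9-channel SD inpainting UNet: 4 (noise latents) + 4 (masked-image
+        # latents) + 1 (mask) = 9, which must equal `unet.config.in_channels`.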
+ if total_latent_channels != self.unet.config.in_channels:
+ raise ValueError(
+                f"The config of `pipeline.unet` expects {self.unet.config.in_channels} input channels but received"
+                f" non-inpainting latent channels: {single_image_latent_channels},"
+ f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
+ f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
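+        # e.g. for a 512x512 output with the SD 1.x VAE (vae_scale_factor=8):
+        # shape == (batch_size, 4, 64, 64)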
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ return latents
+
+ def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
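+        # e.g. a 512x512 mask is downsampled to 64x64 so it can be concatenated with the latents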
+ mask_image = mask_image.to(device=device, dtype=dtype)
+
+ # duplicate mask for each generation per prompt, using mps friendly method
+ if mask_image.shape[0] < batch_size:
+ if not batch_size % mask_image.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
+
+ mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image
+
+ mask_image_latents = mask_image
+
+ return mask_image_latents
+
+ def prepare_masked_image_latents(
+ self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ masked_image_latents = [
+ self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ masked_image_latents = torch.cat(masked_image_latents, dim=0)
+ else:
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+
+ # duplicate masked_image_latents for each generation per prompt, using mps friendly method
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return masked_image_latents
+
+ def _default_height_width(self, height, width, image):
+ if isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+                height = image.shape[2]  # (B, C, H, W) layout: dim 2 is height
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+                width = image.shape[3]  # (B, C, H, W) layout: dim 3 is width
+
+ width = (width // 8) * 8 # round down to nearest multiple of 8
+
+ return height, width
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ controlnet_conditioning_image: Union[
+ torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, 1, H, W)`.
+ controlnet_conditioning_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]`):
+                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the
+                UNet. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The control image is automatically resized to fit the output image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, controlnet_conditioning_image)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare mask, image, and controlnet_conditioning_image
+ image = prepare_image(image)
+
+ mask_image = prepare_mask_image(mask_image)
+
+ # condition image(s)
+ if isinstance(self.controlnet, ControlNetModel):
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image=controlnet_conditioning_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ controlnet_conditioning_images = []
+
+ for image_ in controlnet_conditioning_image:
+ image_ = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ controlnet_conditioning_images.append(image_)
+
+ controlnet_conditioning_image = controlnet_conditioning_images
+ else:
+ assert False
+
+ masked_image = image * (mask_image < 0.5)
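+        # Zero out the region to be repainted (mask >= 0.5); since `image` is in [-1, 1],
+        # masked pixels become 0 (mid-gray) before being encoded by the VAE.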
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ mask_image_latents = self.prepare_mask_latents(
+ mask_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ do_classifier_free_guidance,
+ )
+
+ masked_image_latents = self.prepare_masked_image_latents(
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ non_inpainting_latent_model_input = (
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ )
+
+ non_inpainting_latent_model_input = self.scheduler.scale_model_input(
+ non_inpainting_latent_model_input, t
+ )
+
+ inpainting_latent_model_input = torch.cat(
+ [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
+ )
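+                # The channel-wise concatenation yields 4 + 1 + 4 = 9 channels, matching the
+                # inpainting UNet; the ControlNet below only sees the plain 4-channel latents.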
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ non_inpainting_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=controlnet_conditioning_image,
+ conditioning_scale=controlnet_conditioning_scale,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ inpainting_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/diffusers/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..8928f34239e3a32d7675fdf8590a17f0f0421451
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
@@ -0,0 +1,1037 @@
+# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel, logging
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import numpy as np
+ >>> import torch
+ >>> from PIL import Image
+ >>> from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
+
+ >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> def ade_palette():
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
+ >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
+
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
+
+ >>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ )
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_xformers_memory_efficient_attention()
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> def image_to_seg(image):
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values
+ with torch.no_grad():
+ outputs = image_segmentor(pixel_values)
+ seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
+ palette = np.array(ade_palette())
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ color_seg = color_seg.astype(np.uint8)
+ seg_image = Image.fromarray(color_seg)
+ return seg_image
+
+ >>> image = load_image(
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ )
+
+ >>> mask_image = load_image(
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ )
+
+ >>> controlnet_conditioning_image = image_to_seg(image)
+
+ >>> image = pipe(
+ "Face of a yellow cat, high resolution, sitting on a park bench",
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ num_inference_steps=20,
+ ).images[0]
+
+ >>> image.save("out.png")
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
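+ # move HWC arrays to NCHW and rescale pixel values from [0, 255] to [-1, 1], the range the VAE expects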
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
+
+def prepare_mask_image(mask_image):
+ if isinstance(mask_image, torch.Tensor):
+ if mask_image.ndim == 2:
+ # Batch and add channel dim for single mask
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
+ # Single mask, the 0'th dimension is considered to be
+ # the existing batch size of 1
+ mask_image = mask_image.unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
+ # Batch of mask, the 0'th dimension is considered to be
+ # the batching dimension
+ mask_image = mask_image.unsqueeze(1)
+
+ # Binarize mask
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ else:
+ # preprocess mask
+ if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
+ mask_image = [mask_image]
+
+ if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
+ mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
+ mask_image = mask_image.astype(np.float32) / 255.0
+ elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
+ mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
+
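+ # binarize: values below 0.5 are kept (0), values of 0.5 and above mark the region to repaint (1)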
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ mask_image = torch.from_numpy(mask_image)
+
+ return mask_image
+
+
+def prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
+):
+ if not isinstance(controlnet_conditioning_image, torch.Tensor):
+ if isinstance(controlnet_conditioning_image, PIL.Image.Image):
+ controlnet_conditioning_image = [controlnet_conditioning_image]
+
+ if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
+ controlnet_conditioning_image = [
+ np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
+ for i in controlnet_conditioning_image
+ ]
+ controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
+ controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
+ controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
+ controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
+ elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
+ controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
+
+ image_batch_size = controlnet_conditioning_image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
+
+ controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
+
+ return controlnet_conditioning_image
+
+
+class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin):
+ """
+ Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: ControlNetModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
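+ # for the standard SD VAE (four block_out_channels) this evaluates to 2 ** 3 = 8,
+ # i.e. the latent resolution is height // 8 x width // 8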
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ strength=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
+ controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
+ controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
+ controlnet_conditioning_image[0], PIL.Image.Image
+ )
+ controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
+ controlnet_conditioning_image[0], torch.Tensor
+ )
+
+ if (
+ not controlnet_cond_image_is_pil
+ and not controlnet_cond_image_is_tensor
+ and not controlnet_cond_image_is_pil_list
+ and not controlnet_cond_image_is_tensor_list
+ ):
+ raise TypeError(
+ "`controlnet_conditioning_image` must be passed and be one of: a PIL image, a torch tensor, a list of PIL images, or a list of torch tensors"
+ )
+
+ if controlnet_cond_image_is_pil:
+ controlnet_cond_image_batch_size = 1
+ elif controlnet_cond_image_is_tensor:
+ controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
+ elif controlnet_cond_image_is_pil_list:
+ controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
+ elif controlnet_cond_image_is_tensor_list:
+ controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
+ raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
+
+ if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
+ raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")
+
+ if isinstance(image, torch.Tensor):
+ if image.ndim != 3 and image.ndim != 4:
+ raise ValueError("`image` must have 3 or 4 dimensions")
+
+ if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
+ raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
+
+ if image.ndim == 3:
+ image_batch_size = 1
+ image_channels, image_height, image_width = image.shape
+ elif image.ndim == 4:
+ image_batch_size, image_channels, image_height, image_width = image.shape
+
+ if mask_image.ndim == 2:
+ mask_image_batch_size = 1
+ mask_image_channels = 1
+ mask_image_height, mask_image_width = mask_image.shape
+ elif mask_image.ndim == 3:
+ mask_image_channels = 1
+ mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
+ elif mask_image.ndim == 4:
+ mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape
+
+ if image_channels != 3:
+ raise ValueError("`image` must have 3 channels")
+
+ if mask_image_channels != 1:
+ raise ValueError("`mask_image` must have 1 channel")
+
+ if image_batch_size != mask_image_batch_size:
+ raise ValueError("`image` and `mask_image` must have the same batch sizes")
+
+ if image_height != mask_image_height or image_width != mask_image_width:
+ raise ValueError("`image` and `mask_image` must have the same height and width dimensions")
+
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("`image` should be in range [-1, 1]")
+
+ if mask_image.min() < 0 or mask_image.max() > 1:
+ raise ValueError("`mask_image` should be in range [0, 1]")
+ else:
+ mask_image_channels = 1
+ image_channels = 3
+
+ single_image_latent_channels = self.vae.config.latent_channels
+
+ total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
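+ # for the standard SD inpainting checkpoint: 4 (noise latents) + 4 (masked-image latents) + 1 (mask) = 9 input channels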
+
+ if total_latent_channels != self.unet.config.in_channels:
+ raise ValueError(
+ f"The config of `pipeline.unet` expects {self.unet.config.in_channels} input channels but received"
+ f" non inpainting latent channels: {single_image_latent_channels},"
+ f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
+ f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
+ )
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
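+ # e.g. with num_inference_steps=50 and strength=0.8: init_timestep=40, t_start=10,
+ # so denoising runs over the last 40 scheduler timesteps only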
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # duplicate the image latents to match the requested batch size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
+ mask_image = mask_image.to(device=device, dtype=dtype)
+
+ # duplicate mask for each generation per prompt, using mps friendly method
+ if mask_image.shape[0] < batch_size:
+ if not batch_size % mask_image.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the"
+ " requested batch size is divisible by the number of masks that you pass."
+ )
+ mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
+
+ mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image
+
+ mask_image_latents = mask_image
+
+ return mask_image_latents
+
+ def prepare_masked_image_latents(
+ self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ # encode the masked image into latent space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ masked_image_latents = [
+ self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ masked_image_latents = torch.cat(masked_image_latents, dim=0)
+ else:
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+
+ # duplicate masked_image_latents for each generation per prompt, using mps friendly method
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the requested batch size is divisible by the number of images that you pass."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return masked_image_latents
+
+ def _default_height_width(self, height, width, image):
+ if isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[3]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[2]
+
+ width = (width // 8) * 8 # round down to nearest multiple of 8
+
+ return height, width
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ controlnet_conditioning_image: Union[
+ torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]
+ ] = None,
+ strength: float = 0.8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: float = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ controlnet_conditioning_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The control image is automatically resized to fit the output image.
+ strength (`float`, *optional*):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, controlnet_conditioning_image)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ strength,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare mask, image, and controlnet_conditioning_image
+ image = prepare_image(image)
+
+ mask_image = prepare_mask_image(mask_image)
+
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image,
+ width,
+ height,
+ batch_size * num_images_per_prompt,
+ num_images_per_prompt,
+ device,
+ self.controlnet.dtype,
+ )
+
+ masked_image = image * (mask_image < 0.5)
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
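+ # the image latents are noised to this (first retained) timestep, so only the final
+ # `strength` fraction of the schedule is actually denoised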
+
+ # 6. Prepare latent variables
+ if latents is None:
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ mask_image_latents = self.prepare_mask_latents(
+ mask_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ do_classifier_free_guidance,
+ )
+
+ masked_image_latents = self.prepare_masked_image_latents(
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ if do_classifier_free_guidance:
+ controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ non_inpainting_latent_model_input = (
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ )
+
+ non_inpainting_latent_model_input = self.scheduler.scale_model_input(
+ non_inpainting_latent_model_input, t
+ )
+
+ inpainting_latent_model_input = torch.cat(
+ [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
+ )
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ non_inpainting_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=controlnet_conditioning_image,
+ return_dict=False,
+ )
+
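+ # scale the ControlNet residuals before they are added to the UNet; equivalent to passing
+ # `conditioning_scale` to the ControlNet call as the inpaint pipeline above does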
+ down_block_res_samples = [
+ down_block_res_sample * controlnet_conditioning_scale
+ for down_block_res_sample in down_block_res_samples
+ ]
+ mid_block_res_sample *= controlnet_conditioning_scale
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ inpainting_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_controlnet_reference.py b/diffusers/examples/community/stable_diffusion_controlnet_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..577c7712e7714598f611a7fe07de04238dc4fe57
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_controlnet_reference.py
@@ -0,0 +1,838 @@
+# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers import StableDiffusionControlNetPipeline
+from diffusers.models import ControlNetModel
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import cv2
+ >>> import torch
+ >>> import numpy as np
+ >>> from PIL import Image
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+ >>> # get canny image
+ >>> image = cv2.Canny(np.array(input_image), 100, 200)
+ >>> image = image[:, :, None]
+ >>> image = np.concatenate([image, image, image], axis=2)
+ >>> canny_image = Image.fromarray(image)
+
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ controlnet=controlnet,
+ safety_checker=None,
+ torch_dtype=torch.float16
+ ).to('cuda:0')
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+ >>> result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ image=canny_image,
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+
+ >>> result_img.show()
+ ```
+"""
+
+
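+ # return a module and all of its sub-modules in depth-first order, e.g.
+ # torch_dfs(torch.nn.Sequential(torch.nn.Linear(2, 2))) -> [Sequential, Linear]; used to locate
+ # the attention blocks and group norms that get patched for reference guidance below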
+def torch_dfs(model: torch.nn.Module):
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
+
+
+class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
+ refimage = refimage.to(device=device, dtype=dtype)
+
+ # encode the reference image into latent space
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+
+ # duplicate ref_image_latents for each generation per prompt, using mps friendly method
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the requested batch size is divisible by the number of images that you pass."
+ )
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
+
+ ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+ return ref_image_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.Tensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.Tensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ ref_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ attention_auto_machine_weight: float = 1.0,
+ gn_auto_machine_weight: float = 1.0,
+ style_fidelity: float = 0.5,
+ reference_attn: bool = True,
+ reference_adain: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+ specified in init, images must be passed as a list such that each element of the list can be correctly
+ batched for input to a single controlnet.
+ ref_image (`torch.Tensor`, `PIL.Image.Image`):
+ The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.Tensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
+ also be accepted as an image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even
+ if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ attention_auto_machine_weight (`float`):
+ Weight for using the reference query as the self-attention context.
+ If attention_auto_machine_weight=1.0, the reference query is used for all self-attention contexts.
+ gn_auto_machine_weight (`float`):
+ Weight for using reference AdaIN. If gn_auto_machine_weight=2.0, all reference AdaIN plugins are used.
+ style_fidelity (`float`):
+ Style fidelity of ref_uncond_xt. If style_fidelity=1.0, the control (reference) is more important;
+ if style_fidelity=0.0, the prompt is more important; otherwise the two are balanced.
+ reference_attn (`bool`):
+ Whether to use the reference query as the self-attention context.
+ reference_adain (`bool`):
+ Whether to use reference AdaIN.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
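+ # ControlNets trained with `global_pool_conditions` are expected to run in guess mode,
+ # so guess mode is force-enabled for them here.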
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = image.shape[-2:]
+ elif isinstance(controlnet, MultiControlNetModel):
+ images = []
+
+ for image_ in image:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+
+ image = images
+ height, width = image[0].shape[-2:]
+ else:
+ assert False, "`controlnet` must be a `ControlNetModel` or a `MultiControlNetModel`"
+
+ # 5. Preprocess reference image
+ ref_image = self.prepare_image(
+ image=ref_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=prompt_embeds.dtype,
+ )
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 7. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Prepare reference latent variables
+ ref_image_latents = self.prepare_ref_latents(
+ ref_image,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Modify self attention and group norm
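+ # The reference-only mechanism below monkey-patches the forward methods of the UNet's
+ # transformer blocks (and, when `reference_adain` is enabled, its down/mid/up blocks).
+ # The closure variable `MODE` switches the patched forwards between "write" (cache
+ # features from the reference pass) and "read" (reuse the cached features while denoising).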
+ MODE = "write"
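+ # Mask selecting the unconditional half of the classifier-free-guidance batch; the patched
+ # forwards use it so that `style_fidelity` can keep that half free of reference influence.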
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+ .type_as(ref_image_latents)
+ .bool()
+ )
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
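+ # "write": cache this block's normalized hidden states from the reference pass.
+ # "read": attend over [own features, cached reference features] as extended context,
+ # then blend with a reference-free result according to `style_fidelity`.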
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.detach().clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if attention_auto_machine_weight > self.attn_weight:
+ attn_output_uc = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ # attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_c = attn_output_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ attn_output_c[uc_mask] = self.attn1(
+ norm_hidden_states[uc_mask],
+ encoder_hidden_states=norm_hidden_states[uc_mask],
+ **cross_attention_kwargs,
+ )
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
+ self.bank.clear()
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+ def hacked_mid_forward(self, *args, **kwargs):
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
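+ # Reference AdaIN: re-normalize the current activations with the mean/std statistics
+ # cached from the reference pass ("write" mode), blended according to `style_fidelity`.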
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(
+ self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs
+ ):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ if reference_attn:
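+ # Patch every BasicTransformerBlock so its self-attention can read the cached reference
+ # features; `attn_weight` ranks the blocks, and only blocks with `attn_weight` below
+ # `attention_auto_machine_weight` actually use the reference context.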
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if reference_adain:
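+ # Collect the mid/down/up blocks whose group-norm statistics will be matched to the
+ # reference; `gn_weight` ranks them, and reference AdaIN is applied only where
+ # `gn_auto_machine_weight >= gn_weight`.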
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ # 11. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=controlnet_conditioning_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # Reference-only part: noise the reference latents to the current timestep and run a
+ # "write" pass through the UNet so the patched blocks cache the reference features.
+ noise = randn_tensor(
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
+ )
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape(
+ 1,
+ ),
+ )
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
+ MODE = "write"
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ MODE = "read"
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_ipex.py b/diffusers/examples/community/stable_diffusion_ipex.py
new file mode 100644
index 0000000000000000000000000000000000000000..123892f6229a34da16a87708c8014e80f3ee3385
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_ipex.py
@@ -0,0 +1,754 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import intel_extension_for_pytorch as ipex
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> num_inference_steps = 50
+
+ >>> # For Float32
+ >>> pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) #value of image height/width should be consistent with the pipeline inference
+ >>> # For BFloat16
+ >>> pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) #value of image height/width should be consistent with the pipeline inference
+
+ >>> # For Float32
+ >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
+ >>> # For BFloat16
+ >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
+ ```
+"""
+
+
+class StableDiffusionIPEXPipeline(
+ DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion on IPEX.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
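+ # Each VAE down block halves the spatial resolution, so images are 2 ** (num_blocks - 1)
+ # times larger than their latents (8x for the Stable Diffusion v1 VAE).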
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1):
+ prompt_embeds = None
+ negative_prompt_embeds = None
+ negative_prompt = None
+ callback_steps = 1
+ generator = None
+ latents = None
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+
+ device = "cpu"
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 5. Prepare latent variables
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ self.unet.config.in_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
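+ # Build dummy example inputs (a placeholder timestep plus correctly shaped latents and
+ # prompt embeddings) that are only used to trace the UNet and VAE decoder in `prepare_for_ipex`.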
+ dummy = torch.ones(1, dtype=torch.int32)
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy)
+
+ unet_input_example = (latent_model_input, dummy, prompt_embeds)
+ vae_decoder_input_example = latents
+
+ return unet_input_example, vae_decoder_input_example
+
+ def prepare_for_ipex(self, prompt, dtype=torch.float32, height=None, width=None, guidance_scale=7.5):
+ self.unet = self.unet.to(memory_format=torch.channels_last)
+ self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last)
+ self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last)
+ if self.safety_checker is not None:
+ self.safety_checker = self.safety_checker.to(memory_format=torch.channels_last)
+
+ unet_input_example, vae_decoder_input_example = self.get_input_example(prompt, height, width, guidance_scale)
+
+ # optimize with ipex
+ if dtype == torch.bfloat16:
+ self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True)
+ self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)
+ self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
+ if self.safety_checker is not None:
+ self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
+ elif dtype == torch.float32:
+ self.unet = ipex.optimize(
+ self.unet.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ self.vae.decoder = ipex.optimize(
+ self.vae.decoder.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ self.text_encoder = ipex.optimize(
+ self.text_encoder.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ if self.safety_checker is not None:
+ self.safety_checker = ipex.optimize(
+ self.safety_checker.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ else:
+ raise ValueError(" The value of 'dtype' should be 'torch.bfloat16' or 'torch.float32' !")
+
+ # trace unet model to get better performance on IPEX
+ with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+ unet_trace_model = torch.jit.trace(self.unet, unet_input_example, check_trace=False, strict=False)
+ unet_trace_model = torch.jit.freeze(unet_trace_model)
+ self.unet.forward = unet_trace_model.forward
+
+ # trace vae.decoder model to get better performance on IPEX
+ with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+ vae_decoder_trace_model = torch.jit.trace(
+ self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False
+ )
+ vae_decoder_trace_model = torch.jit.freeze(vae_decoder_trace_model)
+ self.vae.decoder.forward = vae_decoder_trace_model.forward
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that, if specified, is passed along to the `AttnProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)["sample"]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_mega.py b/diffusers/examples/community/stable_diffusion_mega.py
new file mode 100644
index 0000000000000000000000000000000000000000..95b4b03e4de174c48ab9a687b20b7baee740879f
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_mega.py
@@ -0,0 +1,202 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipelineLegacy,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.configuration_utils import FrozenDict
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableDiffusionMegaPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionMegaSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ @property
+ def components(self) -> Dict[str, Any]:
+ return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
+
+ @torch.no_grad()
+ def inpaint(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.Tensor, PIL.Image.Image],
+ mask_image: Union[torch.Tensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
+ return StableDiffusionInpaintPipelineLegacy(**self.components)(
+ prompt=prompt,
+ image=image,
+ mask_image=mask_image,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ )
+
+ @torch.no_grad()
+ def img2img(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.Tensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
+ return StableDiffusionImg2ImgPipeline(**self.components)(
+ prompt=prompt,
+ image=image,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+
+ @torch.no_grad()
+ def text2img(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline
+ return StableDiffusionPipeline(**self.components)(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
diff --git a/diffusers/examples/community/stable_diffusion_reference.py b/diffusers/examples/community/stable_diffusion_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..efb0fa89dbfc2e2e410de0ffb95594cd54300836
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_reference.py
@@ -0,0 +1,1465 @@
+# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ USE_PEFT_BACKEND,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+ >>> pipe = StableDiffusionReferencePipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ safety_checker=None,
+ torch_dtype=torch.float16
+ ).to('cuda:0')
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+ >>> result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+
+ >>> result_img.show()
+ ```
+"""
+
+
+def torch_dfs(model: torch.nn.Module):
+ r"""
+ Performs a depth-first search on the given PyTorch model and returns a list of all its child modules.
+
+ Args:
+ model (torch.nn.Module): The PyTorch model to perform the depth-first search on.
+
+ Returns:
+ list: A list of all child modules of the given model.
+ """
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
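+
+
+# Usage sketch (added for illustration): torch_dfs is how the pipeline enumerates every submodule of
+# a network so that specific block types can be monkey-patched for reference guidance, e.g.
+# attn_modules = [m for m in torch_dfs(unet) if isinstance(m, BasicTransformerBlock)].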
+
+
+class StableDiffusionReferencePipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for Stable Diffusion Reference.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate(
+ "skip_prk_steps not set",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+ # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
+ if unet.config.in_channels != 4:
+ logger.warning(
+ f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
+ f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
+ ". If you did not intend to modify"
+ " this behavior, please check whether you have loaded the right checkpoint."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _default_height_width(
+ self,
+ height: Optional[int],
+ width: Optional[int],
+ image: Union[PIL.Image.Image, torch.Tensor, List[PIL.Image.Image]],
+ ) -> Tuple[int, int]:
+ r"""
+ Calculate the default height and width for the given image.
+
+ Args:
+ height (int or None): The desired height of the image. If None, the height will be determined based on the input image.
+ width (int or None): The desired width of the image. If None, the width will be determined based on the input image.
+ image (PIL.Image.Image or torch.Tensor or list[PIL.Image.Image]): The input image or a list of images.
+
+ Returns:
+ Tuple[int, int]: A tuple containing the calculated height and width.
+
+ """
+ # NOTE: It is possible that a list of images have different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
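+ # e.g. a 513-pixel-tall reference image is rounded down to a height of 512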
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[3]
+
+ width = (width // 8) * 8 # round down to nearest multiple of 8
+
+ return height, width
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt: Optional[Union[str, List[str]]],
+ height: int,
+ width: int,
+ callback_steps: Optional[int],
+ negative_prompt: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[torch.Tensor] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+ ) -> None:
+ """
+ Check the validity of the input arguments for the diffusion model.
+
+ Args:
+ prompt (Optional[Union[str, List[str]]]): The prompt text or list of prompt texts.
+ height (int): The height of the input image.
+ width (int): The width of the input image.
+ callback_steps (Optional[int]): The number of steps to perform the callback on.
+ negative_prompt (Optional[str]): The negative prompt text.
+ prompt_embeds (Optional[torch.Tensor]): The prompt embeddings.
+ negative_prompt_embeds (Optional[torch.Tensor]): The negative prompt embeddings.
+ ip_adapter_image (Optional[torch.Tensor]): The input adapter image.
+ ip_adapter_image_embeds (Optional[torch.Tensor]): The input adapter image embeddings.
+ callback_on_step_end_tensor_inputs (Optional[List[str]]): The list of tensor inputs to perform the callback on.
+
+ Raises:
+ ValueError: If `height` or `width` is not divisible by 8.
+ ValueError: If `callback_steps` is not a positive integer.
+ ValueError: If `callback_on_step_end_tensor_inputs` contains invalid tensor inputs.
+ ValueError: If both `prompt` and `prompt_embeds` are provided.
+ ValueError: If neither `prompt` nor `prompt_embeds` are provided.
+ ValueError: If `prompt` is not of type `str` or `list`.
+ ValueError: If both `negative_prompt` and `negative_prompt_embeds` are provided.
+ ValueError: If both `prompt_embeds` and `negative_prompt_embeds` are provided and have different shapes.
+ ValueError: If both `ip_adapter_image` and `ip_adapter_image_embeds` are provided.
+
+ Returns:
+ None
+ """
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ num_images_per_prompt: int,
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ r"""
+ Encodes the prompt into embeddings.
+
+ Args:
+ prompt (Union[str, List[str]]): The prompt text or a list of prompt texts.
+ device (torch.device): The device to use for encoding.
+ num_images_per_prompt (int): The number of images per prompt.
+ do_classifier_free_guidance (bool): Whether to use classifier-free guidance.
+ negative_prompt (Optional[Union[str, List[str]]], optional): The negative prompt text or a list of negative prompt texts. Defaults to None.
+ prompt_embeds (Optional[torch.Tensor], optional): The prompt embeddings. Defaults to None.
+ negative_prompt_embeds (Optional[torch.Tensor], optional): The negative prompt embeddings. Defaults to None.
+ lora_scale (Optional[float], optional): The LoRA scale. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ torch.Tensor: The encoded prompt embeddings.
+ """
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Optional[str],
+ device: torch.device,
+ num_images_per_prompt: int,
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ) -> torch.Tensor:
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Union[torch.Generator, List[torch.Generator]],
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ r"""
+ Prepare the latent vectors for diffusion.
+
+ Args:
+ batch_size (int): The number of samples in the batch.
+ num_channels_latents (int): The number of channels in the latent vectors.
+ height (int): The height of the latent vectors.
+ width (int): The width of the latent vectors.
+ dtype (torch.dtype): The data type of the latent vectors.
+ device (torch.device): The device to place the latent vectors on.
+ generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation.
+ latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated.
+
+ Returns:
+ torch.Tensor: The prepared latent vectors.
+ """
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
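+ # e.g. for a 512x512 image with vae_scale_factor=8 this is (batch_size, num_channels_latents, 64, 64)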
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(
+ self, generator: Union[torch.Generator, List[torch.Generator]], eta: float
+ ) -> Dict[str, Any]:
+ r"""
+ Prepare extra keyword arguments for the scheduler step.
+
+ Args:
+ generator (Union[torch.Generator, List[torch.Generator]]): The generator used for sampling.
+ eta (float): The value of eta (η) used with the DDIMScheduler. Should be between 0 and 1.
+
+ Returns:
+ Dict[str, Any]: A dictionary containing the extra keyword arguments for the scheduler step.
+ """
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def prepare_image(
+ self,
+ image: Union[torch.Tensor, PIL.Image.Image, List[Union[torch.Tensor, PIL.Image.Image]]],
+ width: int,
+ height: int,
+ batch_size: int,
+ num_images_per_prompt: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ do_classifier_free_guidance: bool = False,
+ guess_mode: bool = False,
+ ) -> torch.Tensor:
+ r"""
+ Prepares the input image for processing.
+
+ Args:
+ image (torch.Tensor or PIL.Image.Image or list): The input image(s).
+ width (int): The desired width of the image.
+ height (int): The desired height of the image.
+ batch_size (int): The batch size for processing.
+ num_images_per_prompt (int): The number of images per prompt.
+ device (torch.device): The device to use for processing.
+ dtype (torch.dtype): The data type of the image.
+ do_classifier_free_guidance (bool, optional): Whether to perform classifier-free guidance. Defaults to False.
+ guess_mode (bool, optional): Whether to use guess mode. Defaults to False.
+
+ Returns:
+ torch.Tensor: The prepared image for processing.
+ """
+ if not isinstance(image, torch.Tensor):
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ images = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+ image_ = np.array(image_)
+ image_ = image_[None, :]
+ images.append(image_)
+
+ image = images
+
+ image = np.concatenate(image, axis=0)
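+ # scale pixel values to [0, 1], normalize to [-1, 1], and convert NHWC -> NCHW for the VAE encoder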
+ image = np.array(image).astype(np.float32) / 255.0
+ image = (image - 0.5) / 0.5
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_ref_latents(
+ self,
+ refimage: torch.Tensor,
+ batch_size: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Union[torch.Generator, List[torch.Generator]],
+ do_classifier_free_guidance: bool,
+ ) -> torch.Tensor:
+ r"""
+ Prepares reference latents for generating images.
+
+ Args:
+ refimage (torch.Tensor): The reference image.
+ batch_size (int): The desired batch size.
+ dtype (torch.dtype): The data type of the tensors.
+ device (torch.device): The device to perform computations on.
+ generator (torch.Generator or list): The generator or list of per-sample generators used for sampling.
+ do_classifier_free_guidance (bool): Whether to use classifier-free guidance.
+
+ Returns:
+ torch.Tensor: The prepared reference latents.
+ """
+ refimage = refimage.to(device=device, dtype=dtype)
+
+ # encode the reference image into latent space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+
+ # duplicate ref_image_latents for each generation per prompt, using mps friendly method
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+ return ref_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(
+ self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+ ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+ r"""
+ Runs the safety checker on the given image.
+
+ Args:
+ image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+ device (torch.device): The device to run the safety checker on.
+ dtype (torch.dtype): The data type of the input image.
+
+ Returns:
+ (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+ a boolean indicating whether the image contains an NSFW (Not Safe For Work) concept.
+ """
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ ref_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ attention_auto_machine_weight: float = 1.0,
+ gn_auto_machine_weight: float = 1.0,
+ style_fidelity: float = 0.5,
+ reference_attn: bool = True,
+ reference_adain: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ ref_image (`torch.Tensor`, `PIL.Image.Image`):
+ The Reference Control input condition. Reference Control uses this input condition to generate guidance for the UNet. If
+ the type is specified as `torch.Tensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
+ also be accepted as an image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ attention_auto_machine_weight (`float`):
+ Weight for using the reference query as context in self-attention.
+ If `attention_auto_machine_weight=1.0`, the reference query is used as context for all self-attention layers.
+ gn_auto_machine_weight (`float`):
+ Weight for applying reference AdaIN. If `gn_auto_machine_weight=2.0`, all reference AdaIN plugins are used.
+ style_fidelity (`float`):
+ Style fidelity of `ref_uncond_xt`. If `style_fidelity=1.0`, the reference (control) is prioritized;
+ if `style_fidelity=0.0`, the prompt is prioritized; intermediate values balance the two.
+ reference_attn (`bool`):
+ Whether to use reference query for self attention's context.
+ reference_adain (`bool`):
+ Whether to use reference adain.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
+
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, ref_image)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess reference image
+ ref_image = self.prepare_image(
+ image=ref_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=prompt_embeds.dtype,
+ )
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare reference latent variables
+ ref_image_latents = self.prepare_ref_latents(
+ ref_image,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Modify self attention and group norm
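+ # Reference-only control runs the patched blocks in two modes: "write" caches the reference pass's
+ # self-attention hidden states and per-layer feature statistics (mean/var banks), while "read"
+ # consumes those caches during the actual denoising pass, concatenating the banked hidden states as
+ # attention context and re-normalizing features with the banked statistics (AdaIN-style).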
+ MODE = "write"
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+ .type_as(ref_image_latents)
+ .bool()
+ )
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.detach().clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if attention_auto_machine_weight > self.attn_weight:
+ attn_output_uc = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ # attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_c = attn_output_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ attn_output_c[uc_mask] = self.attn1(
+ norm_hidden_states[uc_mask],
+ encoder_hidden_states=norm_hidden_states[uc_mask],
+ **cross_attention_kwargs,
+ )
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
+ self.bank.clear()
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
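+ # The block-level hooks below apply AdaIN with the cached reference statistics: current features are
+ # standardized with their own mean/std and re-scaled with the accumulated reference mean/std (eps
+ # guards against zero variance); for the unconditional half of a classifier-free-guidance batch,
+ # `style_fidelity` blends the original features back in to control how strongly the reference is applied.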
+ def hacked_mid_forward(self, *args, **kwargs):
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ **kwargs: Any,
+ ) -> Tuple[torch.Tensor, ...]:
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+ temb: Optional[torch.Tensor] = None,
+ upsample_size: Optional[int] = None,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
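+ # Patch the UNet in place: BasicTransformerBlocks get the reference-attention forward above,
+ # and the down/mid/up blocks optionally get the AdaIN-style forwards when reference_adain is set.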
+ if reference_attn:
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if reference_adain:
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
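+ # gn_weight is highest for the outermost (high-resolution) blocks and lowest at the mid block;
+ # it is compared against `gn_auto_machine_weight` to decide which blocks record statistics.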
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # ref only part
+ noise = randn_tensor(
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
+ )
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape(
+ 1,
+ ),
+ )
+ ref_xt = torch.cat([ref_xt] * 2) if do_classifier_free_guidance else ref_xt
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
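+ # Two UNet passes per step: first the noised reference latent in "write" mode to fill the
+ # attention/statistics banks, then the actual latents in "read" mode to consume them.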
+ MODE = "write"
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ MODE = "read"
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_repaint.py b/diffusers/examples/community/stable_diffusion_repaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..980e9a1559970ae71bb1e00e180151d1f5e19b39
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_repaint.py
@@ -0,0 +1,885 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+ StableDiffusionSafetyChecker,
+)
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ logging,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_mask_and_masked_image(image, mask):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (or the other way around).
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ # masked_image = image * (mask >= 0.5)
+ masked_image = image
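+ # For RePaint the "masked image" is intentionally the full image: the known region is
+ # re-imposed in latent space by the scheduler at every step rather than by pixel masking.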
+
+ return mask, masked_image
+
+
+class StableDiffusionRepaintPipeline(
+ DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`]
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate(
+ "skip_prk_steps not set",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+ # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
+ if unet.config.in_channels != 4:
+ logger.warning(
+ f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
+ f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
+ ". If you did not intend to modify"
+ " this behavior, please check whether you have loaded the right checkpoint."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ masked_image_latents = [
+ self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ masked_image_latents = torch.cat(masked_image_latents, dim=0)
+ else:
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # align the device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ jump_length: Optional[int] = 10,
+ jump_n_sample: Optional[int] = 10,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass
+ `prompt_embeds` instead.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ jump_length (`int`, *optional*, defaults to 10):
+ The number of steps taken forward in time before going backward in time for a single jump ("j" in
+ RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
+ jump_n_sample (`int`, *optional*, defaults to 10):
+ The number of times we will make forward time jump for a given chosen time sample. Take a look at
+ Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ Examples:
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+ >>> from diffusers import DiffusionPipeline, RePaintScheduler
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+ >>> base_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/"
+ >>> img_url = base_url + "overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = base_url + "overture-creations-5sI6fQgYIuo_mask.png"
+ >>> init_image = download_image(img_url).resize((512, 512))
+ >>> mask_image = download_image(mask_url).resize((512, 512))
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, custom_pipeline="stable_diffusion_repaint",
+ ... )
+ >>> pipe.scheduler = RePaintScheduler.from_config(pipe.scheduler.config)
+ >>> pipe = pipe.to("cuda")
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+ ```
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask_image is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Preprocess mask and image
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, device)
+ self.scheduler.eta = eta
+
+ timesteps = self.scheduler.timesteps
+ # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance=False, # We do not need to duplicate the mask and image here
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ # num_channels_mask = mask.shape[1]
+ # num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} "
+ f" = Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ t_last = timesteps[0] + 1
+
+ # 10. Denoising loop
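+ # RePaint timesteps are non-monotonic: whenever the schedule jumps forward in time
+ # (t >= t_last), the latents are re-noised via `undo_step` instead of denoised,
+ # implementing the resampling loop of https://arxiv.org/abs/2201.09865.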
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if t >= t_last:
+ # compute the reverse: x_t-1 -> x_t
+ latents = self.scheduler.undo_step(latents, t_last, generator)
+ progress_bar.update()
+ t_last = t
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
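+ # RePaintScheduler.step also receives the original image latents and the mask so it can
+ # blend the known image content back into the sample at each reverse step.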
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ masked_image_latents,
+ mask,
+ **extra_step_kwargs,
+ ).prev_sample
+
+ # call the callback, if provided
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ t_last = t
+
+ # 11. Post-processing
+ image = self.decode_latents(latents)
+
+ # 12. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 13. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_tensorrt_img2img.py b/diffusers/examples/community/stable_diffusion_tensorrt_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..91540d1f4159d7389f91da4c3ed80d45905970df
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_tensorrt_img2img.py
@@ -0,0 +1,1134 @@
+#
+# Copyright 2024 The HuggingFace Inc. team.
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+from collections import OrderedDict
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+import PIL.Image
+import tensorrt as trt
+import torch
+from cuda import cudart
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import validate_hf_hub_args
+from onnx import shape_inference
+from packaging import version
+from polygraphy import cuda
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.onnx.loader import fold_constants
+from polygraphy.backend.trt import (
+ CreateConfig,
+ Profile,
+ engine_from_bytes,
+ engine_from_network,
+ network_from_onnx_path,
+ save_engine,
+)
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import (
+ StableDiffusionPipelineOutput,
+ StableDiffusionSafetyChecker,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils import logging
+
+
+"""
+Installation instructions
+python3 -m pip install --upgrade transformers "diffusers>=0.16.0"
+python3 -m pip install --upgrade "tensorrt~=10.2.0"
+python3 -m pip install --upgrade "polygraphy>=0.47.0" onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
+python3 -m pip install onnxruntime
+"""
+
+TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Map of numpy dtype -> torch dtype
+numpy_to_torch_dtype_dict = {
+ np.uint8: torch.uint8,
+ np.int8: torch.int8,
+ np.int16: torch.int16,
+ np.int32: torch.int32,
+ np.int64: torch.int64,
+ np.float16: torch.float16,
+ np.float32: torch.float32,
+ np.float64: torch.float64,
+ np.complex64: torch.complex64,
+ np.complex128: torch.complex128,
+}
+if version.parse(np.version.full_version) >= version.parse("1.24.0"):
+ numpy_to_torch_dtype_dict[np.bool_] = torch.bool
+else:
+ numpy_to_torch_dtype_dict[np.bool] = torch.bool
+
+# Map of torch dtype -> numpy dtype
+torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
+
+
+def preprocess_image(image):
+ """
+ image: torch.Tensor
+ """
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h))
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).contiguous()
+ return 2.0 * image - 1.0
+
+
+class Engine:
+ def __init__(self, engine_path):
+ self.engine_path = engine_path
+ self.engine = None
+ self.context = None
+ self.buffers = OrderedDict()
+ self.tensors = OrderedDict()
+
+ def __del__(self):
+ [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
+ del self.engine
+ del self.context
+ del self.buffers
+ del self.tensors
+
+ def build(
+ self,
+ onnx_path,
+ fp16,
+ input_profile=None,
+ enable_all_tactics=False,
+ timing_cache=None,
+ ):
+ logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+ p = Profile()
+ if input_profile:
+ for name, dims in input_profile.items():
+ assert len(dims) == 3
+ p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+ extra_build_args = {}
+ if not enable_all_tactics:
+ extra_build_args["tactic_sources"] = []
+
+ engine = engine_from_network(
+ network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
+ config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
+ save_timing_cache=timing_cache,
+ )
+ save_engine(engine, path=self.engine_path)
+
+ def load(self):
+ logger.warning(f"Loading TensorRT engine: {self.engine_path}")
+ self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+
+ def activate(self):
+ self.context = self.engine.create_execution_context()
+
+ def allocate_buffers(self, shape_dict=None, device="cuda"):
+ for binding in range(self.engine.num_io_tensors):
+ name = self.engine.get_tensor_name(binding)
+ if shape_dict and name in shape_dict:
+ shape = shape_dict[name]
+ else:
+ shape = self.engine.get_tensor_shape(name)
+ dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+ self.context.set_input_shape(name, shape)
+ tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+ self.tensors[name] = tensor
+
+ def infer(self, feed_dict, stream):
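+ # Copy each input into its pre-allocated device tensor, bind every tensor address,
+ # then launch the engine asynchronously on the provided CUDA stream.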
+ for name, buf in feed_dict.items():
+ self.tensors[name].copy_(buf)
+ for name, tensor in self.tensors.items():
+ self.context.set_tensor_address(name, tensor.data_ptr())
+ noerror = self.context.execute_async_v3(stream)
+ if not noerror:
+ raise ValueError("ERROR: inference failed.")
+
+ return self.tensors
+
+
+class Optimizer:
+ def __init__(self, onnx_graph):
+ self.graph = gs.import_onnx(onnx_graph)
+
+ def cleanup(self, return_onnx=False):
+ self.graph.cleanup().toposort()
+ if return_onnx:
+ return gs.export_onnx(self.graph)
+
+ def select_outputs(self, keep, names=None):
+ self.graph.outputs = [self.graph.outputs[o] for o in keep]
+ if names:
+ for i, name in enumerate(names):
+ self.graph.outputs[i].name = name
+
+ def fold_constants(self, return_onnx=False):
+ onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+ def infer_shapes(self, return_onnx=False):
+ onnx_graph = gs.export_onnx(self.graph)
+ if onnx_graph.ByteSize() > 2147483648:
+ raise TypeError("ERROR: model size exceeds supported 2GB limit")
+ else:
+ onnx_graph = shape_inference.infer_shapes(onnx_graph)
+
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+
+class BaseModel:
+ def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
+ self.model = model
+ self.name = "SD Model"
+ self.fp16 = fp16
+ self.device = device
+
+ self.min_batch = 1
+ self.max_batch = max_batch_size
+ self.min_image_shape = 256 # min image resolution: 256x256
+ self.max_image_shape = 1024 # max image resolution: 1024x1024
+ self.min_latent_shape = self.min_image_shape // 8
+ self.max_latent_shape = self.max_image_shape // 8
+
+ self.embedding_dim = embedding_dim
+ self.text_maxlen = text_maxlen
+
+ def get_model(self):
+ return self.model
+
+ def get_input_names(self):
+ pass
+
+ def get_output_names(self):
+ pass
+
+ def get_dynamic_axes(self):
+ return None
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ pass
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ return None
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ return None
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ onnx_opt_graph = opt.cleanup(return_onnx=True)
+ return onnx_opt_graph
+
+ def check_dims(self, batch_size, image_height, image_width):
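+ # Validate batch size and spatial dimensions, then return the latent resolution
+ # (the VAE downsamples each spatial dimension by a factor of 8).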
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+ assert image_height % 8 == 0 and image_width % 8 == 0
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
+ assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
+ return (latent_height, latent_width)
+
+ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ min_image_height = image_height if static_shape else self.min_image_shape
+ max_image_height = image_height if static_shape else self.max_image_shape
+ min_image_width = image_width if static_shape else self.min_image_shape
+ max_image_width = image_width if static_shape else self.max_image_shape
+ min_latent_height = latent_height if static_shape else self.min_latent_shape
+ max_latent_height = latent_height if static_shape else self.max_latent_shape
+ min_latent_width = latent_width if static_shape else self.min_latent_shape
+ max_latent_width = latent_width if static_shape else self.max_latent_shape
+ return (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ )
+
+
+def getOnnxPath(model_name, onnx_dir, opt=True):
+ return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx")
+
+
+def getEnginePath(model_name, engine_dir):
+ return os.path.join(engine_dir, model_name + ".plan")
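+# e.g. getOnnxPath("unet", "onnx") -> "onnx/unet.opt.onnx" ("onnx/unet.onnx" with opt=False),
+# and getEnginePath("unet", "engine") -> "engine/unet.plan".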
+
+
+def build_engines(
+ models: dict,
+ engine_dir,
+ onnx_dir,
+ onnx_opset,
+ opt_image_height,
+ opt_image_width,
+ opt_batch_size=1,
+ force_engine_rebuild=False,
+ static_batch=False,
+ static_shape=True,
+ enable_all_tactics=False,
+ timing_cache=None,
+):
+ built_engines = {}
+ if not os.path.isdir(onnx_dir):
+ os.makedirs(onnx_dir)
+ if not os.path.isdir(engine_dir):
+ os.makedirs(engine_dir)
+
+ # Export models to ONNX
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ if force_engine_rebuild or not os.path.exists(engine_path):
+ logger.warning("Building Engines...")
+ logger.warning("Engine build can take a while to complete")
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ if force_engine_rebuild or not os.path.exists(onnx_path):
+ logger.warning(f"Exporting model: {onnx_path}")
+ model = model_obj.get_model()
+ with torch.inference_mode(), torch.autocast("cuda"):
+ inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
+ torch.onnx.export(
+ model,
+ inputs,
+ onnx_path,
+ export_params=True,
+ opset_version=onnx_opset,
+ do_constant_folding=True,
+ input_names=model_obj.get_input_names(),
+ output_names=model_obj.get_output_names(),
+ dynamic_axes=model_obj.get_dynamic_axes(),
+ )
+ del model
+ torch.cuda.empty_cache()
+ gc.collect()
+ else:
+ logger.warning(f"Found cached model: {onnx_path}")
+
+ # Optimize onnx
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ logger.warning(f"Generating optimizing model: {onnx_opt_path}")
+ onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
+ onnx.save(onnx_opt_graph, onnx_opt_path)
+ else:
+ logger.warning(f"Found cached optimized model: {onnx_opt_path} ")
+
+ # Build TensorRT engines
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ engine = Engine(engine_path)
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+
+ if force_engine_rebuild or not os.path.exists(engine.engine_path):
+ engine.build(
+ onnx_opt_path,
+ fp16=True,
+ input_profile=model_obj.get_input_profile(
+ opt_batch_size,
+ opt_image_height,
+ opt_image_width,
+ static_batch=static_batch,
+ static_shape=static_shape,
+ ),
+ timing_cache=timing_cache,
+ )
+ built_engines[model_name] = engine
+
+ # Load and activate TensorRT engines
+ for model_name, model_obj in models.items():
+ engine = built_engines[model_name]
+ engine.load()
+ engine.activate()
+
+ return built_engines
+
+
+def runEngine(engine, feed_dict, stream):
+ return engine.infer(feed_dict, stream)
+
+
+class CLIP(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(CLIP, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "CLIP"
+
+ def get_input_names(self):
+ return ["input_ids"]
+
+ def get_output_names(self):
+ return ["text_embeddings", "pooler_output"]
+
+ def get_dynamic_axes(self):
+ return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ self.check_dims(batch_size, image_height, image_width)
+ min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
+ batch_size, image_height, image_width, static_batch, static_shape
+ )
+ return {
+ "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return {
+ "input_ids": (batch_size, self.text_maxlen),
+ "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.select_outputs([0]) # delete graph output#1
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ opt.select_outputs([0], names=["text_embeddings"]) # rename network output
+ opt_onnx_graph = opt.cleanup(return_onnx=True)
+ return opt_onnx_graph
+
+
+def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class UNet(BaseModel):
+ def __init__(
+ self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
+ ):
+ super(UNet, self).__init__(
+ model=model,
+ fp16=fp16,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ text_maxlen=text_maxlen,
+ )
+ self.unet_dim = unet_dim
+ self.name = "UNet"
+
+ def get_input_names(self):
+ return ["sample", "timestep", "encoder_hidden_states"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
+ return {
+ "sample": {0: "2B", 2: "H", 3: "W"},
+ "encoder_hidden_states": {0: "2B"},
+ "latent": {0: "2B", 2: "H", 3: "W"},
+ }
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "sample": [
+ (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+ (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+ ],
+ "encoder_hidden_states": [
+ (2 * min_batch, self.text_maxlen, self.embedding_dim),
+ (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ (2 * max_batch, self.text_maxlen, self.embedding_dim),
+ ],
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ "latent": (2 * batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ dtype = torch.float16 if self.fp16 else torch.float32
+ return (
+ torch.randn(
+ 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
+ ),
+ torch.tensor([1.0], dtype=torch.float32, device=self.device),
+ torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+ )
+
+
+def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return UNet(
+ model,
+ fp16=True,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ unet_dim=(9 if inpaint else 4),
+ )
+
+
+class VAE(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAE, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE decoder"
+
+ def get_input_names(self):
+ return ["latent"]
+
+ def get_output_names(self):
+ return ["images"]
+
+ def get_dynamic_axes(self):
+ return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "latent": [
+ (min_batch, 4, min_latent_height, min_latent_width),
+ (batch_size, 4, latent_height, latent_width),
+ (max_batch, 4, max_latent_height, max_latent_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "latent": (batch_size, 4, latent_height, latent_width),
+ "images": (batch_size, 3, image_height, image_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TorchVAEEncoder(torch.nn.Module):
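+    # Thin wrapper so the ONNX export traces only the VAE encode path (image -> sampled latent)
+    # rather than the full AutoencoderKL module.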
+ def __init__(self, model):
+ super().__init__()
+ self.vae_encoder = model
+
+ def forward(self, x):
+ return retrieve_latents(self.vae_encoder.encode(x))
+
+
+class VAEEncoder(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAEEncoder, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE encoder"
+
+ def get_model(self):
+ vae_encoder = TorchVAEEncoder(self.model)
+ return vae_encoder
+
+ def get_input_names(self):
+ return ["images"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
+ return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ _,
+ _,
+ _,
+ _,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+
+ return {
+ "images": [
+ (min_batch, 3, min_image_height, min_image_width),
+ (batch_size, 3, image_height, image_width),
+ (max_batch, 3, max_image_height, max_image_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "images": (batch_size, 3, image_height, image_width),
+ "latent": (batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ stages=["clip", "unet", "vae", "vae_encoder"],
+ image_height: int = 512,
+ image_width: int = 512,
+ max_batch_size: int = 16,
+ # ONNX export parameters
+ onnx_opset: int = 17,
+ onnx_dir: str = "onnx",
+ # TensorRT engine build parameters
+ engine_dir: str = "engine",
+ force_engine_rebuild: bool = False,
+ timing_cache: str = "timing_cache",
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+
+ self.stages = stages
+ self.image_height, self.image_width = image_height, image_width
+ self.inpaint = False
+ self.onnx_opset = onnx_opset
+ self.onnx_dir = onnx_dir
+ self.engine_dir = engine_dir
+ self.force_engine_rebuild = force_engine_rebuild
+ self.timing_cache = timing_cache
+ self.build_static_batch = False
+ self.build_dynamic_shape = False
+
+ self.max_batch_size = max_batch_size
+ # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
+ if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
+ self.max_batch_size = 4
+
+ self.stream = None # loaded in loadResources()
+ self.models = {} # loaded in __loadModels()
+ self.engine = {} # loaded in build_engines()
+
+ self.vae.forward = self.vae.decode
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def __loadModels(self):
+ # Load pipeline models
+ self.embedding_dim = self.text_encoder.config.hidden_size
+ models_args = {
+ "device": self.torch_device,
+ "max_batch_size": self.max_batch_size,
+ "embedding_dim": self.embedding_dim,
+ "inpaint": self.inpaint,
+ }
+ if "clip" in self.stages:
+ self.models["clip"] = make_CLIP(self.text_encoder, **models_args)
+ if "unet" in self.stages:
+ self.models["unet"] = make_UNet(self.unet, **models_args)
+ if "vae" in self.stages:
+ self.models["vae"] = make_VAE(self.vae, **models_args)
+ if "vae_encoder" in self.stages:
+ self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(
+ self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+ ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+ r"""
+ Runs the safety checker on the given image.
+ Args:
+ image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+ device (torch.device): The device to run the safety checker on.
+ dtype (torch.dtype): The data type of the input image.
+ Returns:
+ (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+                a boolean indicating whether the image has an NSFW (Not Safe for Work) concept.
+ """
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ @classmethod
+ @validate_hf_hub_args
+ def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ cache_dir = kwargs.pop("cache_dir", None)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+
+ cls.cached_folder = (
+ pretrained_model_name_or_path
+ if os.path.isdir(pretrained_model_name_or_path)
+ else snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ )
+ )
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
+ super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)
+
+ self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
+ self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
+ self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)
+
+ # set device
+ self.torch_device = self._execution_device
+ logger.warning(f"Running inference on device: {self.torch_device}")
+
+ # load models
+ self.__loadModels()
+
+ # build engines
+ self.engine = build_engines(
+ self.models,
+ self.engine_dir,
+ self.onnx_dir,
+ self.onnx_opset,
+ opt_image_height=self.image_height,
+ opt_image_width=self.image_width,
+ force_engine_rebuild=self.force_engine_rebuild,
+ static_batch=self.build_static_batch,
+ static_shape=not self.build_dynamic_shape,
+ timing_cache=self.timing_cache,
+ )
+
+ return self
+
+ def __initialize_timesteps(self, timesteps, strength):
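+        # e.g. with timesteps=50, strength=0.8 and a scheduler steps_offset of 1:
+        # init_timestep = int(50 * 0.8) + 1 = 41, t_start = max(50 - 41 + 1, 0) = 10,
+        # so only the last 40 of the 50 scheduler timesteps are run.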
+ self.scheduler.set_timesteps(timesteps)
+ offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
+ init_timestep = int(timesteps * strength) + offset
+ init_timestep = min(init_timestep, timesteps)
+ t_start = max(timesteps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device)
+ return timesteps, t_start
+
+ def __preprocess_images(self, batch_size, images=()):
+ init_images = []
+ for image in images:
+ image = image.to(self.torch_device).float()
+ image = image.repeat(batch_size, 1, 1, 1)
+ init_images.append(image)
+ return tuple(init_images)
+
+ def __encode_image(self, init_image):
+ init_latents = runEngine(self.engine["vae_encoder"], {"images": init_image}, self.stream)["latent"]
+        init_latents = 0.18215 * init_latents  # scale by the SD v1/v2 VAE scaling factor (vae.config.scaling_factor)
+ return init_latents
+
+ def __encode_prompt(self, prompt, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ """
+ # Tokenize prompt
+ text_input_ids = (
+ self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+
+ # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
+ text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
+ "text_embeddings"
+ ].clone()
+
+ # Tokenize negative prompt
+ uncond_input_ids = (
+ self.tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+ uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
+ "text_embeddings"
+ ]
+
+ # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
+
+ return text_embeddings
+
+ def __denoise_latent(
+ self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None
+ ):
+ if not isinstance(timesteps, torch.Tensor):
+ timesteps = self.scheduler.timesteps
+ for step_index, timestep in enumerate(timesteps):
+ # Expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+ if isinstance(mask, torch.Tensor):
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # Predict the noise residual
+ timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
+
+ noise_pred = runEngine(
+ self.engine["unet"],
+ {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
+ self.stream,
+ )["latent"]
+
+ # Perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
+
+        latents = 1.0 / 0.18215 * latents  # undo the VAE scaling factor before decoding
+ return latents
+
+ def __decode_latent(self, latents):
+ images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
+ images = (images / 2 + 0.5).clamp(0, 1)
+ return images.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ def __loadResources(self, image_height, image_width, batch_size):
+ self.stream = cudart.cudaStreamCreate()[1]
+
+ # Allocate buffers for TensorRT engine bindings
+ for model_name, obj in self.models.items():
+ self.engine[model_name].allocate_buffers(
+ shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to be used as the starting point for the
+                image-to-image generation process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+
+ """
+ self.generator = generator
+ self.denoising_steps = num_inference_steps
+ self._guidance_scale = guidance_scale
+
+ # Pre-compute latent input scales and linear multistep coefficients
+ self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
+
+ # Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")
+
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+
+ if negative_prompt is not None and isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+
+ assert len(prompt) == len(negative_prompt)
+
+ if batch_size > self.max_batch_size:
+ raise ValueError(
+ f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
+ )
+
+ # load resources
+ self.__loadResources(self.image_height, self.image_width, batch_size)
+
+ with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER):
+ # Initialize timesteps
+ timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+
+ # Pre-process input image
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess_image(image)
+ init_image = self.__preprocess_images(batch_size, (image,))[0]
+
+ # VAE encode init image
+ init_latents = self.__encode_image(init_image)
+
+ # Add noise to latents using timesteps
+ noise = torch.randn(
+ init_latents.shape, generator=self.generator, device=self.torch_device, dtype=torch.float32
+ )
+ latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)
+
+ # CLIP text encoder
+ text_embeddings = self.__encode_prompt(prompt, negative_prompt)
+
+ # UNet denoiser
+ latents = self.__denoise_latent(latents, text_embeddings, timesteps=timesteps, step_offset=t_start)
+
+ # VAE decode latent
+ images = self.__decode_latent(latents)
+
+ images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
+ images = self.numpy_to_pil(images)
+ return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_tensorrt_inpaint.py b/diffusers/examples/community/stable_diffusion_tensorrt_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6f6711a53e7c416100deadaa8d4f6a3169e2d39
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_tensorrt_inpaint.py
@@ -0,0 +1,1268 @@
+#
+# Copyright 2024 The HuggingFace Inc. team.
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+from collections import OrderedDict
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+import PIL.Image
+import tensorrt as trt
+import torch
+from cuda import cudart
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import validate_hf_hub_args
+from onnx import shape_inference
+from packaging import version
+from polygraphy import cuda
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.onnx.loader import fold_constants
+from polygraphy.backend.trt import (
+ CreateConfig,
+ Profile,
+ engine_from_bytes,
+ engine_from_network,
+ network_from_onnx_path,
+ save_engine,
+)
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import (
+ StableDiffusionPipelineOutput,
+ StableDiffusionSafetyChecker,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
+ prepare_mask_and_masked_image,
+ retrieve_latents,
+)
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+"""
+Installation instructions
+python3 -m pip install --upgrade transformers diffusers>=0.16.0
+python3 -m pip install --upgrade tensorrt~=10.2.0
+python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
+python3 -m pip install onnxruntime
+"""
+
+TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Map of numpy dtype -> torch dtype
+numpy_to_torch_dtype_dict = {
+ np.uint8: torch.uint8,
+ np.int8: torch.int8,
+ np.int16: torch.int16,
+ np.int32: torch.int32,
+ np.int64: torch.int64,
+ np.float16: torch.float16,
+ np.float32: torch.float32,
+ np.float64: torch.float64,
+ np.complex64: torch.complex64,
+ np.complex128: torch.complex128,
+}
+if version.parse(np.version.full_version) >= version.parse("1.24.0"):
+    numpy_to_torch_dtype_dict[np.bool_] = torch.bool
+else:
+    numpy_to_torch_dtype_dict[np.bool] = torch.bool
+
+# Map of torch dtype -> numpy dtype
+torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
+
+
+def preprocess_image(image):
+ """
+    image: PIL.Image.Image, resized to a multiple of 32 and converted to an NCHW float tensor in [-1, 1].
+ """
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h))
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).contiguous()
+ return 2.0 * image - 1.0
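+# e.g. a 513x770 (WxH) PIL image is resized to 512x768 and returned as a
+# (1, 3, 768, 512) float tensor scaled to [-1, 1].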
+
+
+class Engine:
+ def __init__(self, engine_path):
+ self.engine_path = engine_path
+ self.engine = None
+ self.context = None
+ self.buffers = OrderedDict()
+ self.tensors = OrderedDict()
+
+ def __del__(self):
+ [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
+ del self.engine
+ del self.context
+ del self.buffers
+ del self.tensors
+
+ def build(
+ self,
+ onnx_path,
+ fp16,
+ input_profile=None,
+ enable_all_tactics=False,
+ timing_cache=None,
+ ):
+ logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+ p = Profile()
+ if input_profile:
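+            # each profile entry is a (min, opt, max) shape triple, as produced by get_input_profile()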
+ for name, dims in input_profile.items():
+ assert len(dims) == 3
+ p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+ extra_build_args = {}
+ if not enable_all_tactics:
+ extra_build_args["tactic_sources"] = []
+
+ engine = engine_from_network(
+ network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
+ config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
+ save_timing_cache=timing_cache,
+ )
+ save_engine(engine, path=self.engine_path)
+
+ def load(self):
+ logger.warning(f"Loading TensorRT engine: {self.engine_path}")
+ self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+
+ def activate(self):
+ self.context = self.engine.create_execution_context()
+
+ def allocate_buffers(self, shape_dict=None, device="cuda"):
+ for binding in range(self.engine.num_io_tensors):
+ name = self.engine.get_tensor_name(binding)
+ if shape_dict and name in shape_dict:
+ shape = shape_dict[name]
+ else:
+ shape = self.engine.get_tensor_shape(name)
+ dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+ self.context.set_input_shape(name, shape)
+ tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+ self.tensors[name] = tensor
+
+ def infer(self, feed_dict, stream):
+ for name, buf in feed_dict.items():
+ self.tensors[name].copy_(buf)
+ for name, tensor in self.tensors.items():
+ self.context.set_tensor_address(name, tensor.data_ptr())
+ noerror = self.context.execute_async_v3(stream)
+ if not noerror:
+ raise ValueError("ERROR: inference failed.")
+
+ return self.tensors
+
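+# Typical Engine lifecycle, as driven by build_engines() and the pipeline class below:
+# build(onnx_opt_path, fp16=True, input_profile=...), then load() and activate(), then
+# allocate_buffers(shape_dict, device) once per resolution/batch, then infer(feed_dict, stream) per call.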
+
+class Optimizer:
+ def __init__(self, onnx_graph):
+ self.graph = gs.import_onnx(onnx_graph)
+
+ def cleanup(self, return_onnx=False):
+ self.graph.cleanup().toposort()
+ if return_onnx:
+ return gs.export_onnx(self.graph)
+
+ def select_outputs(self, keep, names=None):
+ self.graph.outputs = [self.graph.outputs[o] for o in keep]
+ if names:
+ for i, name in enumerate(names):
+ self.graph.outputs[i].name = name
+
+ def fold_constants(self, return_onnx=False):
+ onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+ def infer_shapes(self, return_onnx=False):
+ onnx_graph = gs.export_onnx(self.graph)
+ if onnx_graph.ByteSize() > 2147483648:
+ raise TypeError("ERROR: model size exceeds supported 2GB limit")
+ else:
+ onnx_graph = shape_inference.infer_shapes(onnx_graph)
+
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+
+class BaseModel:
+ def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
+ self.model = model
+ self.name = "SD Model"
+ self.fp16 = fp16
+ self.device = device
+
+ self.min_batch = 1
+ self.max_batch = max_batch_size
+ self.min_image_shape = 256 # min image resolution: 256x256
+ self.max_image_shape = 1024 # max image resolution: 1024x1024
+ self.min_latent_shape = self.min_image_shape // 8
+ self.max_latent_shape = self.max_image_shape // 8
+
+ self.embedding_dim = embedding_dim
+ self.text_maxlen = text_maxlen
+
+ def get_model(self):
+ return self.model
+
+ def get_input_names(self):
+ pass
+
+ def get_output_names(self):
+ pass
+
+ def get_dynamic_axes(self):
+ return None
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ pass
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ return None
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ return None
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ onnx_opt_graph = opt.cleanup(return_onnx=True)
+ return onnx_opt_graph
+
+ def check_dims(self, batch_size, image_height, image_width):
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+        assert image_height % 8 == 0 and image_width % 8 == 0
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
+ assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
+ return (latent_height, latent_width)
+
+ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ min_image_height = image_height if static_shape else self.min_image_shape
+ max_image_height = image_height if static_shape else self.max_image_shape
+ min_image_width = image_width if static_shape else self.min_image_shape
+ max_image_width = image_width if static_shape else self.max_image_shape
+ min_latent_height = latent_height if static_shape else self.min_latent_shape
+ max_latent_height = latent_height if static_shape else self.max_latent_shape
+ min_latent_width = latent_width if static_shape else self.min_latent_shape
+ max_latent_width = latent_width if static_shape else self.max_latent_shape
+ return (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ )
+
+
+def getOnnxPath(model_name, onnx_dir, opt=True):
+ return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx")
+
+
+def getEnginePath(model_name, engine_dir):
+ return os.path.join(engine_dir, model_name + ".plan")
+
+
+def build_engines(
+ models: dict,
+ engine_dir,
+ onnx_dir,
+ onnx_opset,
+ opt_image_height,
+ opt_image_width,
+ opt_batch_size=1,
+ force_engine_rebuild=False,
+ static_batch=False,
+ static_shape=True,
+ enable_all_tactics=False,
+ timing_cache=None,
+):
+ built_engines = {}
+ if not os.path.isdir(onnx_dir):
+ os.makedirs(onnx_dir)
+ if not os.path.isdir(engine_dir):
+ os.makedirs(engine_dir)
+
+ # Export models to ONNX
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ if force_engine_rebuild or not os.path.exists(engine_path):
+ logger.warning("Building Engines...")
+ logger.warning("Engine build can take a while to complete")
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ if force_engine_rebuild or not os.path.exists(onnx_path):
+ logger.warning(f"Exporting model: {onnx_path}")
+ model = model_obj.get_model()
+ with torch.inference_mode(), torch.autocast("cuda"):
+ inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
+ torch.onnx.export(
+ model,
+ inputs,
+ onnx_path,
+ export_params=True,
+ opset_version=onnx_opset,
+ do_constant_folding=True,
+ input_names=model_obj.get_input_names(),
+ output_names=model_obj.get_output_names(),
+ dynamic_axes=model_obj.get_dynamic_axes(),
+ )
+ del model
+ torch.cuda.empty_cache()
+ gc.collect()
+ else:
+ logger.warning(f"Found cached model: {onnx_path}")
+
+ # Optimize onnx
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ logger.warning(f"Generating optimizing model: {onnx_opt_path}")
+ onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
+ onnx.save(onnx_opt_graph, onnx_opt_path)
+ else:
+ logger.warning(f"Found cached optimized model: {onnx_opt_path} ")
+
+ # Build TensorRT engines
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ engine = Engine(engine_path)
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+
+ if force_engine_rebuild or not os.path.exists(engine.engine_path):
+ engine.build(
+ onnx_opt_path,
+ fp16=True,
+ input_profile=model_obj.get_input_profile(
+ opt_batch_size,
+ opt_image_height,
+ opt_image_width,
+ static_batch=static_batch,
+ static_shape=static_shape,
+ ),
+ timing_cache=timing_cache,
+ )
+ built_engines[model_name] = engine
+
+ # Load and activate TensorRT engines
+ for model_name, model_obj in models.items():
+ engine = built_engines[model_name]
+ engine.load()
+ engine.activate()
+
+ return built_engines
+
+
+def runEngine(engine, feed_dict, stream):
+ return engine.infer(feed_dict, stream)
+
+
+class CLIP(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(CLIP, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "CLIP"
+
+ def get_input_names(self):
+ return ["input_ids"]
+
+ def get_output_names(self):
+ return ["text_embeddings", "pooler_output"]
+
+ def get_dynamic_axes(self):
+ return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ self.check_dims(batch_size, image_height, image_width)
+ min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
+ batch_size, image_height, image_width, static_batch, static_shape
+ )
+ return {
+ "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return {
+ "input_ids": (batch_size, self.text_maxlen),
+ "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.select_outputs([0]) # delete graph output#1
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ opt.select_outputs([0], names=["text_embeddings"]) # rename network output
+ opt_onnx_graph = opt.cleanup(return_onnx=True)
+ return opt_onnx_graph
+
+
+def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class UNet(BaseModel):
+ def __init__(
+ self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
+ ):
+ super(UNet, self).__init__(
+ model=model,
+ fp16=fp16,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ text_maxlen=text_maxlen,
+ )
+ self.unet_dim = unet_dim
+ self.name = "UNet"
+
+ def get_input_names(self):
+ return ["sample", "timestep", "encoder_hidden_states"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
+ return {
+ "sample": {0: "2B", 2: "H", 3: "W"},
+ "encoder_hidden_states": {0: "2B"},
+ "latent": {0: "2B", 2: "H", 3: "W"},
+ }
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "sample": [
+ (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+ (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+ ],
+ "encoder_hidden_states": [
+ (2 * min_batch, self.text_maxlen, self.embedding_dim),
+ (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ (2 * max_batch, self.text_maxlen, self.embedding_dim),
+ ],
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ "latent": (2 * batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ dtype = torch.float16 if self.fp16 else torch.float32
+ return (
+ torch.randn(
+ 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
+ ),
+ torch.tensor([1.0], dtype=torch.float32, device=self.device),
+ torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+ )
+
+
+def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False, unet_dim=4):
+ return UNet(
+ model,
+ fp16=True,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ unet_dim=unet_dim,
+ )
+
+
+class VAE(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAE, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE decoder"
+
+ def get_input_names(self):
+ return ["latent"]
+
+ def get_output_names(self):
+ return ["images"]
+
+ def get_dynamic_axes(self):
+ return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "latent": [
+ (min_batch, 4, min_latent_height, min_latent_width),
+ (batch_size, 4, latent_height, latent_width),
+ (max_batch, 4, max_latent_height, max_latent_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "latent": (batch_size, 4, latent_height, latent_width),
+ "images": (batch_size, 3, image_height, image_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TorchVAEEncoder(torch.nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.vae_encoder = model
+
+ def forward(self, x):
+ return self.vae_encoder.encode(x).latent_dist.sample()
+
+
+class VAEEncoder(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAEEncoder, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE encoder"
+
+ def get_model(self):
+ vae_encoder = TorchVAEEncoder(self.model)
+ return vae_encoder
+
+ def get_input_names(self):
+ return ["images"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
+ return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ _,
+ _,
+ _,
+ _,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+
+ return {
+ "images": [
+ (min_batch, 3, min_image_height, min_image_width),
+ (batch_size, 3, image_height, image_width),
+ (max_batch, 3, max_image_height, max_image_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "images": (batch_size, 3, image_height, image_width),
+ "latent": (batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for inpainting using TensorRT accelerated Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ stages=["clip", "unet", "vae", "vae_encoder"],
+ image_height: int = 512,
+ image_width: int = 512,
+ max_batch_size: int = 16,
+ # ONNX export parameters
+ onnx_opset: int = 17,
+ onnx_dir: str = "onnx",
+ # TensorRT engine build parameters
+ engine_dir: str = "engine",
+ force_engine_rebuild: bool = False,
+ timing_cache: str = "timing_cache",
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+
+ self.stages = stages
+ self.image_height, self.image_width = image_height, image_width
+ self.inpaint = True
+ self.onnx_opset = onnx_opset
+ self.onnx_dir = onnx_dir
+ self.engine_dir = engine_dir
+ self.force_engine_rebuild = force_engine_rebuild
+ self.timing_cache = timing_cache
+ self.build_static_batch = False
+ self.build_dynamic_shape = False
+
+ self.max_batch_size = max_batch_size
+ # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
+ if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
+ self.max_batch_size = 4
+
+ self.stream = None # loaded in loadResources()
+ self.models = {} # loaded in __loadModels()
+ self.engine = {} # loaded in build_engines()
+
+ self.vae.forward = self.vae.decode
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def __loadModels(self):
+ # Load pipeline models
+ self.embedding_dim = self.text_encoder.config.hidden_size
+ models_args = {
+ "device": self.torch_device,
+ "max_batch_size": self.max_batch_size,
+ "embedding_dim": self.embedding_dim,
+ "inpaint": self.inpaint,
+ }
+ if "clip" in self.stages:
+ self.models["clip"] = make_CLIP(self.text_encoder, **models_args)
+ if "unet" in self.stages:
+ self.models["unet"] = make_UNet(self.unet, **models_args, unet_dim=self.unet.config.in_channels)
+ if "vae" in self.stages:
+ self.models["vae"] = make_VAE(self.vae, **models_args)
+ if "vae_encoder" in self.stages:
+ self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
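+ # Scale by the VAE's configured scaling factor (0.18215 for SD v1.x) so the latents match the range the UNet expects.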
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+
+ if image.shape[1] == 4:
+ image_latents = image
+ else:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ else:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(
+ self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+ ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+ r"""
+ Runs the safety checker on the given image.
+ Args:
+ image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+ device (torch.device): The device to run the safety checker on.
+ dtype (torch.dtype): The data type of the input image.
+ Returns:
+ (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+ a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
+ """
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ @classmethod
+ @validate_hf_hub_args
+ def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ cache_dir = kwargs.pop("cache_dir", None)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+
+ cls.cached_folder = (
+ pretrained_model_name_or_path
+ if os.path.isdir(pretrained_model_name_or_path)
+ else snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ )
+ )
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
+ super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)
+
+ self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
+ self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
+ self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)
+
+ # set device
+ self.torch_device = self._execution_device
+ logger.warning(f"Running inference on device: {self.torch_device}")
+
+ # load models
+ self.__loadModels()
+
+ # build engines
+ self.engine = build_engines(
+ self.models,
+ self.engine_dir,
+ self.onnx_dir,
+ self.onnx_opset,
+ opt_image_height=self.image_height,
+ opt_image_width=self.image_width,
+ force_engine_rebuild=self.force_engine_rebuild,
+ static_batch=self.build_static_batch,
+ static_shape=not self.build_dynamic_shape,
+ timing_cache=self.timing_cache,
+ )
+
+ return self
+
+ def __initialize_timesteps(self, num_inference_steps, strength):
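+ # Truncate the schedule according to `strength`: strength=1.0 keeps the full schedule, smaller values skip the earliest (noisiest) steps, img2img-style.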
+ self.scheduler.set_timesteps(num_inference_steps)
+ offset = self.scheduler.config.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].to(self.torch_device)
+ return timesteps, num_inference_steps - t_start
+
+ def __preprocess_images(self, batch_size, images=()):
+ init_images = []
+ for image in images:
+ image = image.to(self.torch_device).float()
+ image = image.repeat(batch_size, 1, 1, 1)
+ init_images.append(image)
+ return tuple(init_images)
+
+ def __encode_image(self, init_image):
+ init_latents = runEngine(self.engine["vae_encoder"], {"images": init_image}, self.stream)["latent"]
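+ # 0.18215 is the fixed VAE scaling factor used by SD v1.x checkpoints.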
+ init_latents = 0.18215 * init_latents
+ return init_latents
+
+ def __encode_prompt(self, prompt, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ """
+ # Tokenize prompt
+ text_input_ids = (
+ self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+
+ # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
+ text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
+ "text_embeddings"
+ ].clone()
+
+ # Tokenize negative prompt
+ uncond_input_ids = (
+ self.tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+ uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
+ "text_embeddings"
+ ]
+
+ # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
+
+ return text_embeddings
+
+ def __denoise_latent(
+ self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None
+ ):
+ if not isinstance(timesteps, torch.Tensor):
+ timesteps = self.scheduler.timesteps
+ for step_index, timestep in enumerate(timesteps):
+ # Expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+ if isinstance(mask, torch.Tensor):
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # Predict the noise residual
+ timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
+
+ noise_pred = runEngine(
+ self.engine["unet"],
+ {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
+ self.stream,
+ )["latent"]
+
+ # Perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
+
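+ # Undo the SD v1.x VAE scaling factor (0.18215) before handing the latents to the VAE decoder.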
+ latents = 1.0 / 0.18215 * latents
+ return latents
+
+ def __decode_latent(self, latents):
+ images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
+ images = (images / 2 + 0.5).clamp(0, 1)
+ return images.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ def __loadResources(self, image_height, image_width, batch_size):
+ self.stream = cudart.cudaStreamCreate()[1]
+
+ # Allocate buffers for TensorRT engine bindings
+ for model_name, obj in self.models.items():
+ self.engine[model_name].allocate_buffers(
+ shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ strength (`float`, *optional*, defaults to 1.0):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+
+ """
+ self.generator = generator
+ self.denoising_steps = num_inference_steps
+ self._guidance_scale = guidance_scale
+
+ # Pre-compute latent input scales and linear multistep coefficients
+ self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
+
+ # Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")
+
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+
+ if negative_prompt is not None and isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+
+ assert len(prompt) == len(negative_prompt)
+
+ if batch_size > self.max_batch_size:
+ raise ValueError(
+ f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
+ )
+
+ # Validate image dimensions
+ mask_width, mask_height = mask_image.size
+ if mask_height != self.image_height or mask_width != self.image_width:
+ raise ValueError(
+ f"Input image height and width {self.image_height} and {self.image_width} are not equal to "
+ f"the respective dimensions of the mask image {mask_height} and {mask_width}"
+ )
+
+ # load resources
+ self.__loadResources(self.image_height, self.image_width, batch_size)
+
+ with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER):
+ # Spatial dimensions of latent tensor
+ latent_height = self.image_height // 8
+ latent_width = self.image_width // 8
+
+ # Pre-process input images
+ mask, masked_image, init_image = self.__preprocess_images(
+ batch_size,
+ prepare_mask_and_masked_image(
+ image,
+ mask_image,
+ self.image_height,
+ self.image_width,
+ return_image=True,
+ ),
+ )
+
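+ # Downsample the mask to latent resolution and duplicate it for the two classifier-free-guidance branches.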
+ mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width))
+ mask = torch.cat([mask] * 2)
+
+ # Initialize timesteps
+ timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
+
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # Pre-initialize latents
+ num_channels_latents = self.vae.config.latent_channels
+ latents_outputs = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ self.image_height,
+ self.image_width,
+ torch.float32,
+ self.torch_device,
+ generator,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ )
+
+ latents = latents_outputs[0]
+
+ # VAE encode masked image
+ masked_latents = self.__encode_image(masked_image)
+ masked_latents = torch.cat([masked_latents] * 2)
+
+ # CLIP text encoder
+ text_embeddings = self.__encode_prompt(prompt, negative_prompt)
+
+ # UNet denoiser
+ latents = self.__denoise_latent(
+ latents,
+ text_embeddings,
+ timesteps=timesteps,
+ step_offset=t_start,
+ mask=mask,
+ masked_image_latents=masked_latents,
+ )
+
+ # VAE decode latent
+ images = self.__decode_latent(latents)
+
+ images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
+ images = self.numpy_to_pil(images)
+ return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
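+
+
+# Minimal usage sketch, assuming an SD 1.x inpainting checkpoint such as "runwayml/stable-diffusion-inpainting"
+# and a CUDA machine with TensorRT installed; the model id, init_image, mask_image and prompt below are
+# placeholders to adapt to your setup.
+#
+#   scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-inpainting", subfolder="scheduler")
+#   pipe = TensorRTStableDiffusionInpaintPipeline.from_pretrained(
+#       "runwayml/stable-diffusion-inpainting", scheduler=scheduler
+#   )
+#   pipe.set_cached_folder("runwayml/stable-diffusion-inpainting")  # where ONNX models / TRT engines are cached
+#   pipe = pipe.to("cuda")  # exports ONNX and builds the TensorRT engines on first use
+#   result = pipe(prompt="a red couch", image=init_image, mask_image=mask_image).images[0]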
diff --git a/diffusers/examples/community/stable_diffusion_tensorrt_txt2img.py b/diffusers/examples/community/stable_diffusion_tensorrt_txt2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8761053ed1aae7144ce74b916ac1b63ef2c0f0d
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_tensorrt_txt2img.py
@@ -0,0 +1,1051 @@
+#
+# Copyright 2024 The HuggingFace Inc. team.
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+from collections import OrderedDict
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+import PIL.Image
+import tensorrt as trt
+import torch
+from cuda import cudart
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import validate_hf_hub_args
+from onnx import shape_inference
+from packaging import version
+from polygraphy import cuda
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.onnx.loader import fold_constants
+from polygraphy.backend.trt import (
+ CreateConfig,
+ Profile,
+ engine_from_bytes,
+ engine_from_network,
+ network_from_onnx_path,
+ save_engine,
+)
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import (
+ StableDiffusionPipelineOutput,
+ StableDiffusionSafetyChecker,
+)
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+"""
+Installation instructions
+python3 -m pip install --upgrade transformers diffusers>=0.16.0
+python3 -m pip install --upgrade tensorrt~=10.2.0
+python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
+python3 -m pip install onnxruntime
+"""
+
+TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Map of numpy dtype -> torch dtype
+numpy_to_torch_dtype_dict = {
+ np.uint8: torch.uint8,
+ np.int8: torch.int8,
+ np.int16: torch.int16,
+ np.int32: torch.int32,
+ np.int64: torch.int64,
+ np.float16: torch.float16,
+ np.float32: torch.float32,
+ np.float64: torch.float64,
+ np.complex64: torch.complex64,
+ np.complex128: torch.complex128,
+}
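+# `np.bool` was removed in NumPy 1.24; newer NumPy versions only provide `np.bool_`.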
+if version.parse(np.version.full_version) >= version.parse("1.24.0"):
+ numpy_to_torch_dtype_dict[np.bool_] = torch.bool
+else:
+ numpy_to_torch_dtype_dict[np.bool] = torch.bool
+
+# Map of torch dtype -> numpy dtype
+torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
+
+
+class Engine:
+ def __init__(self, engine_path):
+ self.engine_path = engine_path
+ self.engine = None
+ self.context = None
+ self.buffers = OrderedDict()
+ self.tensors = OrderedDict()
+
+ def __del__(self):
+ [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
+ del self.engine
+ del self.context
+ del self.buffers
+ del self.tensors
+
+ def build(
+ self,
+ onnx_path,
+ fp16,
+ input_profile=None,
+ enable_all_tactics=False,
+ timing_cache=None,
+ ):
+ logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+ p = Profile()
+ if input_profile:
+ for name, dims in input_profile.items():
+ assert len(dims) == 3
+ p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+ extra_build_args = {}
+ if not enable_all_tactics:
+ extra_build_args["tactic_sources"] = []
+
+ engine = engine_from_network(
+ network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
+ config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
+ save_timing_cache=timing_cache,
+ )
+ save_engine(engine, path=self.engine_path)
+
+ def load(self):
+ logger.warning(f"Loading TensorRT engine: {self.engine_path}")
+ self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+
+ def activate(self):
+ self.context = self.engine.create_execution_context()
+
+ def allocate_buffers(self, shape_dict=None, device="cuda"):
+ for binding in range(self.engine.num_io_tensors):
+ name = self.engine.get_tensor_name(binding)
+ if shape_dict and name in shape_dict:
+ shape = shape_dict[name]
+ else:
+ shape = self.engine.get_tensor_shape(name)
+ dtype = trt.nptype(self.engine.get_tensor_dtype(name))
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+ self.context.set_input_shape(name, shape)
+ tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+ self.tensors[name] = tensor
+
+ def infer(self, feed_dict, stream):
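+ # Copy the inputs into the pre-allocated binding tensors, point TensorRT at their device addresses, and launch the engine asynchronously on the given CUDA stream.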
+ for name, buf in feed_dict.items():
+ self.tensors[name].copy_(buf)
+ for name, tensor in self.tensors.items():
+ self.context.set_tensor_address(name, tensor.data_ptr())
+ noerror = self.context.execute_async_v3(stream)
+ if not noerror:
+ raise ValueError("ERROR: inference failed.")
+
+ return self.tensors
+
+
+class Optimizer:
+ def __init__(self, onnx_graph):
+ self.graph = gs.import_onnx(onnx_graph)
+
+ def cleanup(self, return_onnx=False):
+ self.graph.cleanup().toposort()
+ if return_onnx:
+ return gs.export_onnx(self.graph)
+
+ def select_outputs(self, keep, names=None):
+ self.graph.outputs = [self.graph.outputs[o] for o in keep]
+ if names:
+ for i, name in enumerate(names):
+ self.graph.outputs[i].name = name
+
+ def fold_constants(self, return_onnx=False):
+ onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+ def infer_shapes(self, return_onnx=False):
+ onnx_graph = gs.export_onnx(self.graph)
+ if onnx_graph.ByteSize() > 2147483648:
+ raise TypeError("ERROR: model size exceeds supported 2GB limit")
+ else:
+ onnx_graph = shape_inference.infer_shapes(onnx_graph)
+
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+
+class BaseModel:
+ def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
+ self.model = model
+ self.name = "SD Model"
+ self.fp16 = fp16
+ self.device = device
+
+ self.min_batch = 1
+ self.max_batch = max_batch_size
+ self.min_image_shape = 256 # min image resolution: 256x256
+ self.max_image_shape = 1024 # max image resolution: 1024x1024
+ self.min_latent_shape = self.min_image_shape // 8
+ self.max_latent_shape = self.max_image_shape // 8
+
+ self.embedding_dim = embedding_dim
+ self.text_maxlen = text_maxlen
+
+ def get_model(self):
+ return self.model
+
+ def get_input_names(self):
+ pass
+
+ def get_output_names(self):
+ pass
+
+ def get_dynamic_axes(self):
+ return None
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ pass
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ return None
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ return None
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ onnx_opt_graph = opt.cleanup(return_onnx=True)
+ return onnx_opt_graph
+
+ def check_dims(self, batch_size, image_height, image_width):
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+ assert image_height % 8 == 0 and image_width % 8 == 0
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
+ assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
+ return (latent_height, latent_width)
+
+ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ min_image_height = image_height if static_shape else self.min_image_shape
+ max_image_height = image_height if static_shape else self.max_image_shape
+ min_image_width = image_width if static_shape else self.min_image_shape
+ max_image_width = image_width if static_shape else self.max_image_shape
+ min_latent_height = latent_height if static_shape else self.min_latent_shape
+ max_latent_height = latent_height if static_shape else self.max_latent_shape
+ min_latent_width = latent_width if static_shape else self.min_latent_shape
+ max_latent_width = latent_width if static_shape else self.max_latent_shape
+ return (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ )
+
+
+def getOnnxPath(model_name, onnx_dir, opt=True):
+ return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx")
+
+
+def getEnginePath(model_name, engine_dir):
+ return os.path.join(engine_dir, model_name + ".plan")
+
+
+def build_engines(
+ models: dict,
+ engine_dir,
+ onnx_dir,
+ onnx_opset,
+ opt_image_height,
+ opt_image_width,
+ opt_batch_size=1,
+ force_engine_rebuild=False,
+ static_batch=False,
+ static_shape=True,
+ enable_all_tactics=False,
+ timing_cache=None,
+):
+ built_engines = {}
+ if not os.path.isdir(onnx_dir):
+ os.makedirs(onnx_dir)
+ if not os.path.isdir(engine_dir):
+ os.makedirs(engine_dir)
+
+ # Export models to ONNX
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ if force_engine_rebuild or not os.path.exists(engine_path):
+ logger.warning("Building Engines...")
+ logger.warning("Engine build can take a while to complete")
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ if force_engine_rebuild or not os.path.exists(onnx_path):
+ logger.warning(f"Exporting model: {onnx_path}")
+ model = model_obj.get_model()
+ with torch.inference_mode(), torch.autocast("cuda"):
+ inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
+ torch.onnx.export(
+ model,
+ inputs,
+ onnx_path,
+ export_params=True,
+ opset_version=onnx_opset,
+ do_constant_folding=True,
+ input_names=model_obj.get_input_names(),
+ output_names=model_obj.get_output_names(),
+ dynamic_axes=model_obj.get_dynamic_axes(),
+ )
+ del model
+ torch.cuda.empty_cache()
+ gc.collect()
+ else:
+ logger.warning(f"Found cached model: {onnx_path}")
+
+ # Optimize onnx
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ logger.warning(f"Generating optimizing model: {onnx_opt_path}")
+ onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
+ onnx.save(onnx_opt_graph, onnx_opt_path)
+ else:
+ logger.warning(f"Found cached optimized model: {onnx_opt_path} ")
+
+ # Build TensorRT engines
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ engine = Engine(engine_path)
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+
+ if force_engine_rebuild or not os.path.exists(engine.engine_path):
+ engine.build(
+ onnx_opt_path,
+ fp16=True,
+ input_profile=model_obj.get_input_profile(
+ opt_batch_size,
+ opt_image_height,
+ opt_image_width,
+ static_batch=static_batch,
+ static_shape=static_shape,
+ ),
+ timing_cache=timing_cache,
+ )
+ built_engines[model_name] = engine
+
+ # Load and activate TensorRT engines
+ for model_name, model_obj in models.items():
+ engine = built_engines[model_name]
+ engine.load()
+ engine.activate()
+
+ return built_engines
+
+
+def runEngine(engine, feed_dict, stream):
+ return engine.infer(feed_dict, stream)
+
+
+class CLIP(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(CLIP, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "CLIP"
+
+ def get_input_names(self):
+ return ["input_ids"]
+
+ def get_output_names(self):
+ return ["text_embeddings", "pooler_output"]
+
+ def get_dynamic_axes(self):
+ return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ self.check_dims(batch_size, image_height, image_width)
+ min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
+ batch_size, image_height, image_width, static_batch, static_shape
+ )
+ return {
+ "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return {
+ "input_ids": (batch_size, self.text_maxlen),
+ "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.select_outputs([0]) # delete graph output#1
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ opt.select_outputs([0], names=["text_embeddings"]) # rename network output
+ opt_onnx_graph = opt.cleanup(return_onnx=True)
+ return opt_onnx_graph
+
+
+def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class UNet(BaseModel):
+ def __init__(
+ self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
+ ):
+ super(UNet, self).__init__(
+ model=model,
+ fp16=fp16,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ text_maxlen=text_maxlen,
+ )
+ self.unet_dim = unet_dim
+ self.name = "UNet"
+
+ def get_input_names(self):
+ return ["sample", "timestep", "encoder_hidden_states"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
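+ # The leading dimension is "2B" because classifier-free guidance runs the unconditional and conditional batches together.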
+ return {
+ "sample": {0: "2B", 2: "H", 3: "W"},
+ "encoder_hidden_states": {0: "2B"},
+ "latent": {0: "2B", 2: "H", 3: "W"},
+ }
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "sample": [
+ (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+ (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+ ],
+ "encoder_hidden_states": [
+ (2 * min_batch, self.text_maxlen, self.embedding_dim),
+ (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ (2 * max_batch, self.text_maxlen, self.embedding_dim),
+ ],
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ "latent": (2 * batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ dtype = torch.float16 if self.fp16 else torch.float32
+ return (
+ torch.randn(
+ 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
+ ),
+ torch.tensor([1.0], dtype=torch.float32, device=self.device),
+ torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+ )
+
+
+def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False):
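+ # Inpainting UNets take 9 input channels (4 noisy latents + 1 mask + 4 masked-image latents); text-to-image UNets take 4.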
+ return UNet(
+ model,
+ fp16=True,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ unet_dim=(9 if inpaint else 4),
+ )
+
+
+class VAE(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAE, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE decoder"
+
+ def get_input_names(self):
+ return ["latent"]
+
+ def get_output_names(self):
+ return ["images"]
+
+ def get_dynamic_axes(self):
+ return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "latent": [
+ (min_batch, 4, min_latent_height, min_latent_width),
+ (batch_size, 4, latent_height, latent_width),
+ (max_batch, 4, max_latent_height, max_latent_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "latent": (batch_size, 4, latent_height, latent_width),
+ "images": (batch_size, 3, image_height, image_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TensorRTStableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ stages=["clip", "unet", "vae"],
+ image_height: int = 768,
+ image_width: int = 768,
+ max_batch_size: int = 16,
+ # ONNX export parameters
+ onnx_opset: int = 18,
+ onnx_dir: str = "onnx",
+ # TensorRT engine build parameters
+ engine_dir: str = "engine",
+ force_engine_rebuild: bool = False,
+ timing_cache: str = "timing_cache",
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+
+ self.stages = stages
+ self.image_height, self.image_width = image_height, image_width
+ self.inpaint = False
+ self.onnx_opset = onnx_opset
+ self.onnx_dir = onnx_dir
+ self.engine_dir = engine_dir
+ self.force_engine_rebuild = force_engine_rebuild
+ self.timing_cache = timing_cache
+ self.build_static_batch = False
+ self.build_dynamic_shape = False
+
+ self.max_batch_size = max_batch_size
+ # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
+ if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
+ self.max_batch_size = 4
+
+ self.stream = None # loaded in loadResources()
+ self.models = {} # loaded in __loadModels()
+ self.engine = {} # loaded in build_engines()
+
+ self.vae.forward = self.vae.decode
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def __loadModels(self):
+ # Load pipeline models
+ self.embedding_dim = self.text_encoder.config.hidden_size
+ models_args = {
+ "device": self.torch_device,
+ "max_batch_size": self.max_batch_size,
+ "embedding_dim": self.embedding_dim,
+ "inpaint": self.inpaint,
+ }
+ if "clip" in self.stages:
+ self.models["clip"] = make_CLIP(self.text_encoder, **models_args)
+ if "unet" in self.stages:
+ self.models["unet"] = make_UNet(self.unet, **models_args)
+ if "vae" in self.stages:
+ self.models["vae"] = make_VAE(self.vae, **models_args)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Union[torch.Generator, List[torch.Generator]],
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ r"""
+ Prepare the latent vectors for diffusion.
+ Args:
+ batch_size (int): The number of samples in the batch.
+ num_channels_latents (int): The number of channels in the latent vectors.
+ height (int): The height of the latent vectors.
+ width (int): The width of the latent vectors.
+ dtype (torch.dtype): The data type of the latent vectors.
+ device (torch.device): The device to place the latent vectors on.
+ generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation.
+ latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated.
+ Returns:
+ torch.Tensor: The prepared latent vectors.
+ """
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(
+ self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
+ ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
+ r"""
+ Runs the safety checker on the given image.
+ Args:
+ image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
+ device (torch.device): The device to run the safety checker on.
+ dtype (torch.dtype): The data type of the input image.
+ Returns:
+ (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
+ a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
+ """
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ @classmethod
+ @validate_hf_hub_args
+ def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ cache_dir = kwargs.pop("cache_dir", None)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+
+ cls.cached_folder = (
+ pretrained_model_name_or_path
+ if os.path.isdir(pretrained_model_name_or_path)
+ else snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ )
+ )
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
+ super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)
+
+ self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
+ self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
+ self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)
+
+ # set device
+ self.torch_device = self._execution_device
+ logger.warning(f"Running inference on device: {self.torch_device}")
+
+ # load models
+ self.__loadModels()
+
+ # build engines
+ self.engine = build_engines(
+ self.models,
+ self.engine_dir,
+ self.onnx_dir,
+ self.onnx_opset,
+ opt_image_height=self.image_height,
+ opt_image_width=self.image_width,
+ force_engine_rebuild=self.force_engine_rebuild,
+ static_batch=self.build_static_batch,
+ static_shape=not self.build_dynamic_shape,
+ timing_cache=self.timing_cache,
+ )
+
+ return self
+
+ def __encode_prompt(self, prompt, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ """
+ # Tokenize prompt
+ text_input_ids = (
+ self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+
+ # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
+ text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
+ "text_embeddings"
+ ].clone()
+
+ # Tokenize negative prompt
+ uncond_input_ids = (
+ self.tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+ uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
+ "text_embeddings"
+ ]
+
+ # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
+
+ return text_embeddings
+
+ def __denoise_latent(
+ self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None
+ ):
+ if not isinstance(timesteps, torch.Tensor):
+ timesteps = self.scheduler.timesteps
+ for step_index, timestep in enumerate(timesteps):
+ # Expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+ if isinstance(mask, torch.Tensor):
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # Predict the noise residual
+ timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
+
+ noise_pred = runEngine(
+ self.engine["unet"],
+ {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
+ self.stream,
+ )["latent"]
+
+ # Perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
+
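+ # Undo the SD v1.x VAE scaling factor (0.18215) before handing the latents to the VAE decoder.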
+ latents = 1.0 / 0.18215 * latents
+ return latents
+
+ def __decode_latent(self, latents):
+ images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
+ images = (images / 2 + 0.5).clamp(0, 1)
+ return images.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ def __loadResources(self, image_height, image_width, batch_size):
+ self.stream = cudart.cudaStreamCreate()[1]
+
+ # Allocate buffers for TensorRT engine bindings
+ for model_name, obj in self.models.items():
+ self.engine[model_name].allocate_buffers(
+ shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+
+ """
+ self.generator = generator
+ self.denoising_steps = num_inference_steps
+ self._guidance_scale = guidance_scale
+
+ # Pre-compute latent input scales and linear multistep coefficients
+ self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
+
+ # Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")
+
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+
+ if negative_prompt is not None and isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+
+ assert len(prompt) == len(negative_prompt)
+
+ if batch_size > self.max_batch_size:
+ raise ValueError(
+ f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
+ )
+
+ # load resources
+ self.__loadResources(self.image_height, self.image_width, batch_size)
+
+ with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER):
+ # CLIP text encoder
+ text_embeddings = self.__encode_prompt(prompt, negative_prompt)
+
+ # Pre-initialize latents
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ self.image_height,
+ self.image_width,
+ torch.float32,
+ self.torch_device,
+ generator,
+ )
+
+ # UNet denoiser
+ latents = self.__denoise_latent(latents, text_embeddings)
+
+ # VAE decode latent
+ images = self.__decode_latent(latents)
+
+ images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
+ images = self.numpy_to_pil(images)
+ return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_xl_reference.py b/diffusers/examples/community/stable_diffusion_xl_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..107afc1f8b7aadc77738910af1b92939d4f8f914
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_xl_reference.py
@@ -0,0 +1,818 @@
+# Based on stable_diffusion_reference.py
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers import StableDiffusionXLPipeline
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.unets.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UpBlock2D,
+)
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.utils import PIL_INTERPOLATION, logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+ >>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16").to('cuda:0')
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+
+ >>> result_img.show()
+ ```
+"""
+
+
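+# Depth-first traversal: returns a flat list containing `model` followed by all of its submodules.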
+def torch_dfs(model: torch.nn.Module):
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
+ def _default_height_width(self, height, width, image):
+ # NOTE: It is possible that a list of images have different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[3]
+
+ width = (width // 8) * 8
+
+ return height, width
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if not isinstance(image, torch.Tensor):
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ images = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+ image_ = np.array(image_)
+ image_ = image_[None, :]
+ images.append(image_)
+
+ image = images
+
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = (image - 0.5) / 0.5
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.stack(image, dim=0)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
+ refimage = refimage.to(device=device)
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ if refimage.dtype != self.vae.dtype:
+ refimage = refimage.to(dtype=self.vae.dtype)
+        # encode the reference image into latent space so it can be reused during denoising
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+
+        # duplicate ref_image_latents for each generation per prompt, using mps friendly method
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
+
+ ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
+
+        # align device and dtype to prevent errors when it is used alongside the latent model input
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+ return ref_image_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ ref_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ attention_auto_machine_weight: float = 1.0,
+ gn_auto_machine_weight: float = 1.0,
+ style_fidelity: float = 0.5,
+ reference_attn: bool = True,
+ reference_adain: bool = True,
+ ):
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
+
+ # 0. Default height and width to unet
+ # height, width = self._default_height_width(height, width, ref_image)
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+ # 4. Preprocess reference image
+ ref_image = self.prepare_image(
+ image=ref_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=prompt_embeds.dtype,
+ )
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ # 7. Prepare reference latent variables
+ ref_image_latents = self.prepare_ref_latents(
+ ref_image,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 9. Modify self-attention and group norm forward passes for reference-only control
+ MODE = "write"
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+ .type_as(ref_image_latents)
+ .bool()
+ )
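+        # MODE == "write": the UNet runs on the noised reference latents and every hooked module stores its
+        # self-attention features and/or GroupNorm statistics in a bank.
+        # MODE == "read": while denoising the actual latents, the hooked modules reuse those banks to steer the
+        # self-attention (reference_attn) and to re-normalize features to the reference statistics (reference_adain).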
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.detach().clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if attention_auto_machine_weight > self.attn_weight:
+ attn_output_uc = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ # attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_c = attn_output_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ attn_output_c[uc_mask] = self.attn1(
+ norm_hidden_states[uc_mask],
+ encoder_hidden_states=norm_hidden_states[uc_mask],
+ **cross_attention_kwargs,
+ )
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
+ self.bank.clear()
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+ def hacked_mid_forward(self, *args, **kwargs):
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(
+ self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
+ ):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ if reference_attn:
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if reference_adain:
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ # 10. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 10.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+                # reference-only step: noise the reference latents to the current timestep and run the UNet
+                # in "write" mode so the hooked modules record the reference features/statistics
+ noise = randn_tensor(
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
+ )
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape(
+ 1,
+ ),
+ )
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
+ MODE = "write"
+
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ MODE = "read"
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/stable_unclip.py b/diffusers/examples/community/stable_unclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..f13c4e0a490bdc114eb2dbc3cb5ed9253912f965
--- /dev/null
+++ b/diffusers/examples/community/stable_unclip.py
@@ -0,0 +1,288 @@
+import types
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+
+from diffusers.models import PriorTransformer
+from diffusers.pipelines import DiffusionPipeline, StableDiffusionImageVariationPipeline
+from diffusers.schedulers import UnCLIPScheduler
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+ image = image.to(device=device)
+    image_embeddings = image  # the prior's predicted image embeddings are passed in directly in place of a raw image
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ uncond_embeddings = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+
+ return image_embeddings
+
+
+class StableUnCLIPPipeline(DiffusionPipeline):
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModelWithProjection,
+ prior_scheduler: UnCLIPScheduler,
+ decoder_pipe_kwargs: Optional[dict] = None,
+ ):
+ super().__init__()
+
+ decoder_pipe_kwargs = {"image_encoder": None} if decoder_pipe_kwargs is None else decoder_pipe_kwargs
+
+ decoder_pipe_kwargs["torch_dtype"] = decoder_pipe_kwargs.get("torch_dtype", None) or prior.dtype
+
+ self.decoder_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
+ "lambdalabs/sd-image-variations-diffusers", **decoder_pipe_kwargs
+ )
+
+ # replace `_encode_image` method
+ self.decoder_pipe._encode_image = types.MethodType(_encode_image, self.decoder_pipe)
+
+ self.register_modules(
+ prior=prior,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ prior_scheduler=prior_scheduler,
+ )
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if text_model_output is None:
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ text_embeddings = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ else:
+ batch_size = text_model_output[0].shape[0]
+ text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ text_mask = text_attention_mask
+
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return text_embeddings, text_encoder_hidden_states, text_mask
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.prior, "_hf_hook"):
+ return self.device
+ for module in self.prior.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+        self.decoder_pipe.to(torch_device)
+        super().to(torch_device)
+        return self
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_images_per_prompt: int = 1,
+ prior_num_inference_steps: int = 25,
+ generator: Optional[torch.Generator] = None,
+ prior_latents: Optional[torch.Tensor] = None,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ prior_guidance_scale: float = 4.0,
+ decoder_guidance_scale: float = 8.0,
+ decoder_num_inference_steps: int = 50,
+ decoder_num_images_per_prompt: Optional[int] = 1,
+ decoder_eta: float = 0.0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ if prompt is not None:
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ else:
+ batch_size = text_model_output[0].shape[0]
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
+
+ text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
+ )
+
+ # prior
+
+ self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
+ prior_timesteps_tensor = self.prior_scheduler.timesteps
+
+ embedding_dim = self.prior.config.embedding_dim
+
+ prior_latents = self.prepare_latents(
+ (batch_size, embedding_dim),
+ text_embeddings.dtype,
+ device,
+ generator,
+ prior_latents,
+ self.prior_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=text_embeddings,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ if i + 1 == prior_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = prior_timesteps_tensor[i + 1]
+
+ prior_latents = self.prior_scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=prior_latents,
+ generator=generator,
+ prev_timestep=prev_timestep,
+ ).prev_sample
+
+ prior_latents = self.prior.post_process_latents(prior_latents)
+
+ image_embeddings = prior_latents
+
+ output = self.decoder_pipe(
+ image=image_embeddings,
+ height=height,
+ width=width,
+ num_inference_steps=decoder_num_inference_steps,
+ guidance_scale=decoder_guidance_scale,
+ generator=generator,
+ output_type=output_type,
+ return_dict=return_dict,
+ num_images_per_prompt=decoder_num_images_per_prompt,
+ eta=decoder_eta,
+ )
+ return output
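+
+
+# Usage sketch (illustrative only; the checkpoint names below are assumptions and are not pinned by this file):
+#
+#   from diffusers.models import PriorTransformer
+#   from diffusers.schedulers import UnCLIPScheduler
+#   from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+#
+#   prior = PriorTransformer.from_pretrained("kakaobrain/karlo-v1-alpha", subfolder="prior")
+#   prior_scheduler = UnCLIPScheduler.from_pretrained("kakaobrain/karlo-v1-alpha", subfolder="prior_scheduler")
+#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+#   text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+#
+#   pipe = StableUnCLIPPipeline(prior, tokenizer, text_encoder, prior_scheduler)
+#   pipe.to("cuda")
+#   image = pipe("a photo of an astronaut riding a horse on mars").images[0]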
diff --git a/diffusers/examples/community/text_inpainting.py b/diffusers/examples/community/text_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4378ab96f28972988ce419e7370d582621cd69d
--- /dev/null
+++ b/diffusers/examples/community/text_inpainting.py
@@ -0,0 +1,240 @@
+from typing import Callable, List, Optional, Union
+
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPSegForImageSegmentation,
+ CLIPSegProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+)
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for text based inpainting using Stable Diffusion.
+ Uses CLIPSeg to get a mask from the given text, then calls the Inpainting pipeline with the generated mask
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ segmentation_model ([`CLIPSegForImageSegmentation`]):
+            CLIPSeg model used to generate a mask from the given text. Please refer to the
+            [model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details.
+        segmentation_processor ([`CLIPSegProcessor`]):
+            CLIPSeg processor used to prepare the image and text inputs for the segmentation model. Please refer
+            to the [model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ segmentation_model: CLIPSegForImageSegmentation,
+ segmentation_processor: CLIPSegProcessor,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ segmentation_model=segmentation_model,
+ segmentation_processor=segmentation_processor,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.Tensor, PIL.Image.Image],
+ text: str,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+            text (`str`):
+ The text to use to generate the mask.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # We use the input text to generate the mask
+ inputs = self.segmentation_processor(
+ text=[text], images=[image], padding="max_length", return_tensors="pt"
+ ).to(self.device)
+ outputs = self.segmentation_model(**inputs)
+ mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy()
+ mask_pil = self.numpy_to_pil(mask)[0].resize(image.size)
+
+ # Run inpainting pipeline with the generated mask
+ inpainting_pipeline = StableDiffusionInpaintPipeline(
+ vae=self.vae,
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ unet=self.unet,
+ scheduler=self.scheduler,
+ safety_checker=self.safety_checker,
+ feature_extractor=self.feature_extractor,
+ )
+ return inpainting_pipeline(
+ prompt=prompt,
+ image=image,
+ mask_image=mask_pil,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
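+
+
+# Usage sketch (illustrative only; the checkpoint names below are assumptions and are not pinned by this file):
+#
+#   from diffusers import DiffusionPipeline
+#   from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
+#
+#   seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+#   seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+#   pipe = DiffusionPipeline.from_pretrained(
+#       "runwayml/stable-diffusion-inpainting",
+#       custom_pipeline="text_inpainting",
+#       segmentation_model=seg_model,
+#       segmentation_processor=seg_processor,
+#   ).to("cuda")
+#
+#   init_image = ...  # PIL.Image of the scene to edit
+#   result = pipe(prompt="a cup of coffee", image=init_image, text="the mug").images[0]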
diff --git a/diffusers/examples/community/tiled_upscaling.py b/diffusers/examples/community/tiled_upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fe7399c8aeddf9693fe8e2d07ac5bd696679dc2
--- /dev/null
+++ b/diffusers/examples/community/tiled_upscaling.py
@@ -0,0 +1,298 @@
+# Copyright 2024 Peter Willemsen. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from PIL import Image
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
+from diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+
+def make_transparency_mask(size, overlap_pixels, remove_borders=[]):
+ size_x = size[0] - overlap_pixels * 2
+ size_y = size[1] - overlap_pixels * 2
+ for letter in ["l", "r"]:
+ if letter in remove_borders:
+ size_x += overlap_pixels
+ for letter in ["t", "b"]:
+ if letter in remove_borders:
+ size_y += overlap_pixels
+ mask = np.ones((size_y, size_x), dtype=np.uint8) * 255
+ mask = np.pad(mask, mode="linear_ramp", pad_width=overlap_pixels, end_values=0)
+
+ if "l" in remove_borders:
+ mask = mask[:, overlap_pixels : mask.shape[1]]
+ if "r" in remove_borders:
+ mask = mask[:, 0 : mask.shape[1] - overlap_pixels]
+ if "t" in remove_borders:
+ mask = mask[overlap_pixels : mask.shape[0], :]
+ if "b" in remove_borders:
+ mask = mask[0 : mask.shape[0] - overlap_pixels, :]
+ return mask
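+
+
+# Hedged illustration (added comment, not part of the original pipeline): for a
+# 64x64 tile with an 8-pixel overlap and no removed borders, the returned mask is a
+# 64x64 uint8 array that ramps from 0 at the outer edge to 255 in the 48x48
+# interior, so overlapping tiles blend smoothly when pasted with it as the alpha mask, e.g.
+#   mask = make_transparency_mask((64, 64), overlap_pixels=8)
+#   assert mask.shape == (64, 64) and mask[0, 0] == 0 and mask[32, 32] == 255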
+
+
+def clamp(n, smallest, largest):
+ return max(smallest, min(n, largest))
+
+
+def clamp_rect(rect: List[int], min: List[int], max: List[int]):
+ return (
+ clamp(rect[0], min[0], max[0]),
+ clamp(rect[1], min[1], max[1]),
+ clamp(rect[2], min[0], max[0]),
+ clamp(rect[3], min[1], max[1]),
+ )
+
+
+def add_overlap_rect(rect: List[int], overlap: int, image_size: List[int]):
+ rect = list(rect)
+ rect[0] -= overlap
+ rect[1] -= overlap
+ rect[2] += overlap
+ rect[3] += overlap
+ rect = clamp_rect(rect, [0, 0], [image_size[0], image_size[1]])
+ return rect
+
+
+def squeeze_tile(tile, original_image, original_slice, slice_x):
+ result = Image.new("RGB", (tile.size[0] + original_slice, tile.size[1]))
+ result.paste(
+ original_image.resize((tile.size[0], tile.size[1]), Image.BICUBIC).crop(
+ (slice_x, 0, slice_x + original_slice, tile.size[1])
+ ),
+ (0, 0),
+ )
+ result.paste(tile, (original_slice, 0))
+ return result
+
+
+def unsqueeze_tile(tile, original_image_slice):
+ crop_rect = (original_image_slice * 4, 0, tile.size[0], tile.size[1])
+ tile = tile.crop(crop_rect)
+ return tile
+
+
+def next_divisible(n, d):
+ remainder = n % d
+ return n - remainder
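+
+
+# Hedged note (added comment, not in the original file): despite its name, this
+# helper rounds *down* to the nearest multiple of d, e.g. next_divisible(130, 64) == 128.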
+
+
+class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
+ r"""
+ Pipeline for tile-based text-guided image super-resolution using Stable Diffusion 2, trading memory for compute
+ to create gigantic images.
+
+ This model inherits from [`StableDiffusionUpscalePipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ low_res_scheduler ([`SchedulerMixin`]):
+ A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
+ [`DDPMScheduler`].
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ low_res_scheduler: DDPMScheduler,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ max_noise_level: int = 350,
+ ):
+ super().__init__(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ low_res_scheduler=low_res_scheduler,
+ scheduler=scheduler,
+ max_noise_level=max_noise_level,
+ )
+
+ def _process_tile(self, original_image_slice, x, y, tile_size, tile_border, image, final_image, **kwargs):
+ torch.manual_seed(0)
+ crop_rect = (
+ min(image.size[0] - (tile_size + original_image_slice), x * tile_size),
+ min(image.size[1] - (tile_size + original_image_slice), y * tile_size),
+ min(image.size[0], (x + 1) * tile_size),
+ min(image.size[1], (y + 1) * tile_size),
+ )
+ crop_rect_with_overlap = add_overlap_rect(crop_rect, tile_border, image.size)
+ tile = image.crop(crop_rect_with_overlap)
+ translated_slice_x = ((crop_rect[0] + ((crop_rect[2] - crop_rect[0]) / 2)) / image.size[0]) * tile.size[0]
+ translated_slice_x = translated_slice_x - (original_image_slice / 2)
+ translated_slice_x = max(0, translated_slice_x)
+ to_input = squeeze_tile(tile, image, original_image_slice, translated_slice_x)
+ orig_input_size = to_input.size
+ to_input = to_input.resize((tile_size, tile_size), Image.BICUBIC)
+ upscaled_tile = super(StableDiffusionTiledUpscalePipeline, self).__call__(image=to_input, **kwargs).images[0]
+ upscaled_tile = upscaled_tile.resize((orig_input_size[0] * 4, orig_input_size[1] * 4), Image.BICUBIC)
+ upscaled_tile = unsqueeze_tile(upscaled_tile, original_image_slice)
+ upscaled_tile = upscaled_tile.resize((tile.size[0] * 4, tile.size[1] * 4), Image.BICUBIC)
+ remove_borders = []
+ if x == 0:
+ remove_borders.append("l")
+ elif crop_rect[2] == image.size[0]:
+ remove_borders.append("r")
+ if y == 0:
+ remove_borders.append("t")
+ elif crop_rect[3] == image.size[1]:
+ remove_borders.append("b")
+ transparency_mask = Image.fromarray(
+ make_transparency_mask(
+ (upscaled_tile.size[0], upscaled_tile.size[1]), tile_border * 4, remove_borders=remove_borders
+ ),
+ mode="L",
+ )
+ final_image.paste(
+ upscaled_tile, (crop_rect_with_overlap[0] * 4, crop_rect_with_overlap[1] * 4), transparency_mask
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[PIL.Image.Image, List[PIL.Image.Image]],
+ num_inference_steps: int = 75,
+ guidance_scale: float = 9.0,
+ noise_level: int = 50,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ tile_size: int = 128,
+ tile_border: int = 32,
+ original_image_slice: int = 32,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.Tensor`):
+ `Image`, or tensor representing an image batch, which will be upscaled.
+ num_inference_steps (`int`, *optional*, defaults to 75):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 9.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ tile_size (`int`, *optional*):
+ The size of the tiles, in pixels. Values that are too large can result in an out-of-memory (OOM) error.
+ tile_border (`int`, *optional*):
+ The number of pixels around a tile to consider (larger values mean fewer seams, but values that are too
+ large can lead to an OOM error).
+ original_image_slice (`int`, *optional*):
+ The number of pixels of the original image to process together with the current tile (larger values preserve
+ more depth and reduce blur in the final image, but values that are too large can lead to an OOM error or a
+ loss of detail).
+ callback (`Callable`, *optional*):
+ A function that is called after each processed tile with a single argument: a dict that contains the
+ (partially) processed image under "image" and the progress (0 to 1, where 1 is completed) under "progress".
+
+ Returns: A PIL.Image that is 4 times larger than the original input image.
+
+ """
+
+ final_image = Image.new("RGB", (image.size[0] * 4, image.size[1] * 4))
+ tcx = math.ceil(image.size[0] / tile_size)
+ tcy = math.ceil(image.size[1] / tile_size)
+ total_tile_count = tcx * tcy
+ current_count = 0
+ for y in range(tcy):
+ for x in range(tcx):
+ self._process_tile(
+ original_image_slice,
+ x,
+ y,
+ tile_size,
+ tile_border,
+ image,
+ final_image,
+ prompt=prompt,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ noise_level=noise_level,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ )
+ current_count += 1
+ if callback is not None:
+ callback({"progress": current_count / total_tile_count, "image": final_image})
+ return final_image
+
+
+def main():
+ # Run a demo
+ model_id = "stabilityai/stable-diffusion-x4-upscaler"
+ pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, variant="fp16", torch_dtype=torch.float16)
+ pipe = pipe.to("cuda")
+ image = Image.open("../../docs/source/imgs/diffusers_library.jpg")
+
+ def callback(obj):
+ print(f"progress: {obj['progress']:.4f}")
+ obj["image"].save("diffusers_library_progress.jpg")
+
+ final_image = pipe(image=image, prompt="Black font, white background, vector", noise_level=40, callback=callback)
+ final_image.save("diffusers_library.jpg")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/community/unclip_image_interpolation.py b/diffusers/examples/community/unclip_image_interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..210bd61ecd1d46523e3c48061703133e5a585fc0
--- /dev/null
+++ b/diffusers/examples/community/unclip_image_interpolation.py
@@ -0,0 +1,452 @@
+import inspect
+from typing import List, Optional, Union
+
+import PIL.Image
+import torch
+from torch.nn import functional as F
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers import (
+ DiffusionPipeline,
+ ImagePipelineOutput,
+ UnCLIPScheduler,
+ UNet2DConditionModel,
+ UNet2DModel,
+)
+from diffusers.pipelines.unclip import UnCLIPTextProjModel
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def slerp(val, low, high):
+ """
+ Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
+ """
+ low_norm = low / torch.norm(low)
+ high_norm = high / torch.norm(high)
+ omega = torch.acos((low_norm * high_norm))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
+ return res
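+
+
+# Hedged usage sketch (added comment, not part of the original file): slerp(0.0, a, b)
+# returns `a` and slerp(1.0, a, b) returns `b`; intermediate values follow a spherical
+# path between the two embeddings, e.g.
+#   mid = slerp(0.5, image_embeddings[0], image_embeddings[1])
+# which tends to stay closer to the embedding distribution than plain linear interpolation.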
+
+
+class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
+ """
+ Pipeline to generate interpolations between two input images using unCLIP
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `image_encoder`.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_proj ([`UnCLIPTextProjModel`]):
+ Utility class to prepare and combine the embeddings before they are passed to the decoder.
+ decoder ([`UNet2DConditionModel`]):
+ The decoder to invert the image embedding into an image.
+ super_res_first ([`UNet2DModel`]):
+ Super resolution unet. Used in all but the last step of the super resolution diffusion process.
+ super_res_last ([`UNet2DModel`]):
+ Super resolution unet. Used in the last step of the super resolution diffusion process.
+ decoder_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
+ super_res_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
+
+ """
+
+ decoder: UNet2DConditionModel
+ text_proj: UnCLIPTextProjModel
+ text_encoder: CLIPTextModelWithProjection
+ tokenizer: CLIPTokenizer
+ feature_extractor: CLIPImageProcessor
+ image_encoder: CLIPVisionModelWithProjection
+ super_res_first: UNet2DModel
+ super_res_last: UNet2DModel
+
+ decoder_scheduler: UnCLIPScheduler
+ super_res_scheduler: UnCLIPScheduler
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.__init__
+ def __init__(
+ self,
+ decoder: UNet2DConditionModel,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_proj: UnCLIPTextProjModel,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection,
+ super_res_first: UNet2DModel,
+ super_res_last: UNet2DModel,
+ decoder_scheduler: UnCLIPScheduler,
+ super_res_scheduler: UnCLIPScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ decoder=decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_proj=text_proj,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ super_res_first=super_res_first,
+ super_res_last=super_res_last,
+ decoder_scheduler=decoder_scheduler,
+ super_res_scheduler=super_res_scheduler,
+ )
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_prompt
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_image
+ def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if image_embeddings is None:
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+
+ image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ return image_embeddings
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None,
+ steps: int = 5,
+ decoder_num_inference_steps: int = 25,
+ super_res_num_inference_steps: int = 7,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ image_embeddings: Optional[torch.Tensor] = None,
+ decoder_latents: Optional[torch.Tensor] = None,
+ super_res_latents: Optional[torch.Tensor] = None,
+ decoder_guidance_scale: float = 8.0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`List[PIL.Image.Image]` or `torch.Tensor`):
+ The images to use for the image interpolation. Only a list of two PIL Images is accepted; if you provide a tensor instead, it needs to comply with the
+ configuration of
+ [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
+ `CLIPImageProcessor` while still having a size of two in the 0th dimension. Can be left as `None` only when `image_embeddings` are passed.
+ steps (`int`, *optional*, defaults to 5):
+ The number of interpolation images to generate.
+ decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+ The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+ quality image at the expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ image_embeddings (`torch.Tensor`, *optional*):
+ Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
+ can be passed for tasks like image interpolations. `image` can then be left as `None`.
+ decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*):
+ Pre-generated noisy latents to be used as inputs for the decoder.
+ super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*):
+ Pre-generated noisy latents to be used as inputs for the super resolution step.
+ decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ """
+
+ batch_size = steps
+
+ device = self._execution_device
+
+ if isinstance(image, List):
+ if len(image) != 2:
+ raise AssertionError(
+ f"Expected 'image' List to be of size 2, but passed 'image' length is {len(image)}"
+ )
+ elif not (isinstance(image[0], PIL.Image.Image) and isinstance(image[1], PIL.Image.Image)):
+ raise AssertionError(
+ f"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}"
+ )
+ elif isinstance(image, torch.Tensor):
+ if image.shape[0] != 2:
+ raise AssertionError(
+ f"Expected 'image' to be torch.Tensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
+ )
+ elif isinstance(image_embeddings, torch.Tensor):
+ if image_embeddings.shape[0] != 2:
+ raise AssertionError(
+ f"Expected 'image_embeddings' to be torch.Tensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
+ )
+ else:
+ raise AssertionError(
+ f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or torch.Tensor respectively. Received {type(image)} and {type(image_embeddings)} repsectively"
+ )
+
+ original_image_embeddings = self._encode_image(
+ image=image, device=device, num_images_per_prompt=1, image_embeddings=image_embeddings
+ )
+
+ image_embeddings = []
+
+ for interp_step in torch.linspace(0, 1, steps):
+ temp_image_embeddings = slerp(
+ interp_step, original_image_embeddings[0], original_image_embeddings[1]
+ ).unsqueeze(0)
+ image_embeddings.append(temp_image_embeddings)
+
+ image_embeddings = torch.cat(image_embeddings).to(device)
+
+ do_classifier_free_guidance = decoder_guidance_scale > 1.0
+
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt=["" for i in range(steps)],
+ device=device,
+ num_images_per_prompt=1,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ image_embeddings=image_embeddings,
+ prompt_embeds=prompt_embeds,
+ text_encoder_hidden_states=text_encoder_hidden_states,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ if device.type == "mps":
+ # HACK: MPS: There is a panic when padding bool tensors,
+ # so cast to int tensor for the pad and back to bool afterwards
+ text_mask = text_mask.type(torch.int)
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+ decoder_text_mask = decoder_text_mask.type(torch.bool)
+ else:
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
+
+ self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
+ decoder_timesteps_tensor = self.decoder_scheduler.timesteps
+
+ num_channels_latents = self.decoder.config.in_channels
+ height = self.decoder.config.sample_size
+ width = self.decoder.config.sample_size
+
+ # Get the decoder latents for 1 step and then repeat the same tensor for the entire batch to keep same noise across all interpolation steps.
+ decoder_latents = self.prepare_latents(
+ (1, num_channels_latents, height, width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ decoder_latents,
+ self.decoder_scheduler,
+ )
+ decoder_latents = decoder_latents.repeat((batch_size, 1, 1, 1))
+
+ for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
+
+ noise_pred = self.decoder(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ class_labels=additive_clip_time_embeddings,
+ attention_mask=decoder_text_mask,
+ ).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if i + 1 == decoder_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = decoder_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ decoder_latents = self.decoder_scheduler.step(
+ noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ decoder_latents = decoder_latents.clamp(-1, 1)
+
+ image_small = decoder_latents
+
+ # done decoder
+
+ # super res
+
+ self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
+ super_res_timesteps_tensor = self.super_res_scheduler.timesteps
+
+ channels = self.super_res_first.config.in_channels // 2
+ height = self.super_res_first.config.sample_size
+ width = self.super_res_first.config.sample_size
+
+ super_res_latents = self.prepare_latents(
+ (batch_size, channels, height, width),
+ image_small.dtype,
+ device,
+ generator,
+ super_res_latents,
+ self.super_res_scheduler,
+ )
+
+ if device.type == "mps":
+ # MPS does not support many interpolations
+ image_upscaled = F.interpolate(image_small, size=[height, width])
+ else:
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
+ # no classifier free guidance
+
+ if i == super_res_timesteps_tensor.shape[0] - 1:
+ unet = self.super_res_last
+ else:
+ unet = self.super_res_first
+
+ latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
+
+ noise_pred = unet(
+ sample=latent_model_input,
+ timestep=t,
+ ).sample
+
+ if i + 1 == super_res_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = super_res_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ super_res_latents = self.super_res_scheduler.step(
+ noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ image = super_res_latents
+ # done super res
+
+ # post processing
+
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
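+
+
+def _interpolation_demo():
+ # Hedged usage sketch, not part of the original pipeline file. The checkpoint id and
+ # the image paths below are assumptions; swap in your own unCLIP image-variation
+ # checkpoint and a pair of input images.
+ pipe = UnCLIPImageInterpolationPipeline.from_pretrained(
+ "kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float16
+ )
+ pipe = pipe.to("cuda")
+ images = [PIL.Image.open("start.png"), PIL.Image.open("end.png")]
+ # Returns `steps` images interpolated between the two inputs.
+ output = pipe(image=images, steps=6, decoder_guidance_scale=8.0)
+ for i, frame in enumerate(output.images):
+ frame.save(f"unclip_interpolation_{i}.png")
+
+
+if __name__ == "__main__":
+ _interpolation_demo()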
diff --git a/diffusers/examples/community/unclip_text_interpolation.py b/diffusers/examples/community/unclip_text_interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..84f1c5a21f7a2c0ceb434815fe4e2b6272e83b01
--- /dev/null
+++ b/diffusers/examples/community/unclip_text_interpolation.py
@@ -0,0 +1,528 @@
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import functional as F
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+
+from diffusers import (
+ DiffusionPipeline,
+ ImagePipelineOutput,
+ PriorTransformer,
+ UnCLIPScheduler,
+ UNet2DConditionModel,
+ UNet2DModel,
+)
+from diffusers.pipelines.unclip import UnCLIPTextProjModel
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def slerp(val, low, high):
+ """
+ Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
+ """
+ low_norm = low / torch.norm(low)
+ high_norm = high / torch.norm(high)
+ omega = torch.acos((low_norm * high_norm))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
+ return res
+
+
+class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
+ """
+ Pipeline for prompt-to-prompt interpolation on CLIP text embeddings, using unCLIP / DALL-E 2 style models to decode them to images.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ prior ([`PriorTransformer`]):
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ text_proj ([`UnCLIPTextProjModel`]):
+ Utility class to prepare and combine the embeddings before they are passed to the decoder.
+ decoder ([`UNet2DConditionModel`]):
+ The decoder to invert the image embedding into an image.
+ super_res_first ([`UNet2DModel`]):
+ Super resolution unet. Used in all but the last step of the super resolution diffusion process.
+ super_res_last ([`UNet2DModel`]):
+ Super resolution unet. Used in the last step of the super resolution diffusion process.
+ prior_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
+ decoder_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
+ super_res_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
+
+ """
+
+ prior: PriorTransformer
+ decoder: UNet2DConditionModel
+ text_proj: UnCLIPTextProjModel
+ text_encoder: CLIPTextModelWithProjection
+ tokenizer: CLIPTokenizer
+ super_res_first: UNet2DModel
+ super_res_last: UNet2DModel
+
+ prior_scheduler: UnCLIPScheduler
+ decoder_scheduler: UnCLIPScheduler
+ super_res_scheduler: UnCLIPScheduler
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__init__
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ decoder: UNet2DConditionModel,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_proj: UnCLIPTextProjModel,
+ super_res_first: UNet2DModel,
+ super_res_last: UNet2DModel,
+ prior_scheduler: UnCLIPScheduler,
+ decoder_scheduler: UnCLIPScheduler,
+ super_res_scheduler: UnCLIPScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior=prior,
+ decoder=decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_proj=text_proj,
+ super_res_first=super_res_first,
+ super_res_last=super_res_last,
+ prior_scheduler=prior_scheduler,
+ decoder_scheduler=decoder_scheduler,
+ super_res_scheduler=super_res_scheduler,
+ )
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if text_model_output is None:
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ else:
+ batch_size = text_model_output[0].shape[0]
+ prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ text_mask = text_attention_mask
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ start_prompt: str,
+ end_prompt: str,
+ steps: int = 5,
+ prior_num_inference_steps: int = 25,
+ decoder_num_inference_steps: int = 25,
+ super_res_num_inference_steps: int = 7,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prior_guidance_scale: float = 4.0,
+ decoder_guidance_scale: float = 8.0,
+ enable_sequential_cpu_offload=True,
+ gpu_id=0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ start_prompt (`str`):
+ The prompt to start the image generation interpolation from.
+ end_prompt (`str`):
+ The prompt to end the image generation interpolation at.
+ steps (`int`, *optional*, defaults to 5):
+ The number of steps over which to interpolate from start_prompt to end_prompt. The pipeline returns
+ the same number of images as this value.
+ prior_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+ The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+ quality image at the expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prior_guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ enable_sequential_cpu_offload (`bool`, *optional*, defaults to `True`):
+ If True, offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+ models have their state dicts saved to CPU and are then moved to a `torch.device('meta')` and loaded to the GPU only
+ when their specific submodule has its `forward` method called.
+ gpu_id (`int`, *optional*, defaults to `0`):
+ The gpu_id to be passed to enable_sequential_cpu_offload. Only works when enable_sequential_cpu_offload is set to True.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ """
+
+ if not isinstance(start_prompt, str) or not isinstance(end_prompt, str):
+ raise ValueError(
+ f"`start_prompt` and `end_prompt` should be of type `str` but got {type(start_prompt)} and"
+ f" {type(end_prompt)} instead"
+ )
+
+ if enable_sequential_cpu_offload:
+ self.enable_sequential_cpu_offload(gpu_id=gpu_id)
+
+ device = self._execution_device
+
+ # Turn the prompts into embeddings.
+ inputs = self.tokenizer(
+ [start_prompt, end_prompt],
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ inputs.to(device)
+ text_model_output = self.text_encoder(**inputs)
+
+ text_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1])
+ text_attention_mask = torch.cat([text_attention_mask.unsqueeze(0)] * steps).to(device)
+
+ # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline
+ batch_text_embeds = []
+ batch_last_hidden_state = []
+
+ for interp_val in torch.linspace(0, 1, steps):
+ text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1])
+ last_hidden_state = slerp(
+ interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1]
+ )
+ batch_text_embeds.append(text_embeds.unsqueeze(0))
+ batch_last_hidden_state.append(last_hidden_state.unsqueeze(0))
+
+ batch_text_embeds = torch.cat(batch_text_embeds)
+ batch_last_hidden_state = torch.cat(batch_last_hidden_state)
+
+ text_model_output = CLIPTextModelOutput(
+ text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state
+ )
+
+ batch_size = text_model_output[0].shape[0]
+
+ do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
+
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt=None,
+ device=device,
+ num_images_per_prompt=1,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ text_model_output=text_model_output,
+ text_attention_mask=text_attention_mask,
+ )
+
+ # prior
+
+ self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
+ prior_timesteps_tensor = self.prior_scheduler.timesteps
+
+ embedding_dim = self.prior.config.embedding_dim
+
+ prior_latents = self.prepare_latents(
+ (batch_size, embedding_dim),
+ prompt_embeds.dtype,
+ device,
+ generator,
+ None,
+ self.prior_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ if i + 1 == prior_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = prior_timesteps_tensor[i + 1]
+
+ prior_latents = self.prior_scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=prior_latents,
+ generator=generator,
+ prev_timestep=prev_timestep,
+ ).prev_sample
+
+ prior_latents = self.prior.post_process_latents(prior_latents)
+
+ image_embeddings = prior_latents
+
+ # done prior
+
+ # decoder
+
+ text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ image_embeddings=image_embeddings,
+ prompt_embeds=prompt_embeds,
+ text_encoder_hidden_states=text_encoder_hidden_states,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ if device.type == "mps":
+ # HACK: MPS: There is a panic when padding bool tensors,
+ # so cast to int tensor for the pad and back to bool afterwards
+ text_mask = text_mask.type(torch.int)
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+ decoder_text_mask = decoder_text_mask.type(torch.bool)
+ else:
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
+
+ self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
+ decoder_timesteps_tensor = self.decoder_scheduler.timesteps
+
+ num_channels_latents = self.decoder.config.in_channels
+ height = self.decoder.config.sample_size
+ width = self.decoder.config.sample_size
+
+ decoder_latents = self.prepare_latents(
+ (batch_size, num_channels_latents, height, width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ None,
+ self.decoder_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
+
+ noise_pred = self.decoder(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ class_labels=additive_clip_time_embeddings,
+ attention_mask=decoder_text_mask,
+ ).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if i + 1 == decoder_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = decoder_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ decoder_latents = self.decoder_scheduler.step(
+ noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ decoder_latents = decoder_latents.clamp(-1, 1)
+
+ image_small = decoder_latents
+
+ # done decoder
+
+ # super res
+
+ self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
+ super_res_timesteps_tensor = self.super_res_scheduler.timesteps
+
+ channels = self.super_res_first.config.in_channels // 2
+ height = self.super_res_first.config.sample_size
+ width = self.super_res_first.config.sample_size
+
+ super_res_latents = self.prepare_latents(
+ (batch_size, channels, height, width),
+ image_small.dtype,
+ device,
+ generator,
+ None,
+ self.super_res_scheduler,
+ )
+
+ if device.type == "mps":
+ # MPS does not support many interpolations
+ image_upscaled = F.interpolate(image_small, size=[height, width])
+ else:
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
+ # no classifier free guidance
+
+ if i == super_res_timesteps_tensor.shape[0] - 1:
+ unet = self.super_res_last
+ else:
+ unet = self.super_res_first
+
+ latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
+
+ noise_pred = unet(
+ sample=latent_model_input,
+ timestep=t,
+ ).sample
+
+ if i + 1 == super_res_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = super_res_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ super_res_latents = self.super_res_scheduler.step(
+ noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ image = super_res_latents
+ # done super res
+
+ # post processing
+
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
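+
+
+def _text_interpolation_demo():
+ # Hedged usage sketch, not part of the original pipeline file. The checkpoint id is
+ # an assumption; any unCLIP checkpoint exposing prior/decoder/super-res components
+ # with the names registered above should work.
+ pipe = UnCLIPTextInterpolationPipeline.from_pretrained(
+ "kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16
+ )
+ pipe = pipe.to("cuda")
+ output = pipe(
+ start_prompt="a photo of a victorian house",
+ end_prompt="a photo of a modern glass house",
+ steps=6,
+ enable_sequential_cpu_offload=False,
+ )
+ for i, frame in enumerate(output.images):
+ frame.save(f"unclip_text_interpolation_{i}.png")
+
+
+if __name__ == "__main__":
+ _text_interpolation_demo()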
diff --git a/diffusers/examples/community/wildcard_stable_diffusion.py b/diffusers/examples/community/wildcard_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..c866ce2ae9042e8e3ae33b98732330e95aa42962
--- /dev/null
+++ b/diffusers/examples/community/wildcard_stable_diffusion.py
@@ -0,0 +1,419 @@
+import inspect
+import os
+import random
+import re
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+global_re_wildcard = re.compile(r"__([^_]*)__")
+
+
+def get_filename(path: str):
+ # this doesn't work on Windows
+ return os.path.basename(path).split(".txt")[0]
+
+
+def read_wildcard_values(path: str):
+ with open(path, encoding="utf8") as f:
+ return f.read().splitlines()
+
+
+def grab_wildcard_values(wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []):
+ for wildcard_file in wildcard_files:
+ filename = get_filename(wildcard_file)
+ read_values = read_wildcard_values(wildcard_file)
+ if filename not in wildcard_option_dict:
+ wildcard_option_dict[filename] = []
+ wildcard_option_dict[filename].extend(read_values)
+ return wildcard_option_dict
+
+
+def replace_prompt_with_wildcards(
+ prompt: str, wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []
+):
+ new_prompt = prompt
+
+ # get wildcard options
+ wildcard_option_dict = grab_wildcard_values(wildcard_option_dict, wildcard_files)
+
+ for m in global_re_wildcard.finditer(new_prompt):
+ wildcard_value = m.group()
+ replace_value = random.choice(wildcard_option_dict[wildcard_value.strip("__")])
+ new_prompt = new_prompt.replace(wildcard_value, replace_value, 1)
+
+ return new_prompt
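+
+
+# Hedged example (added comment, not in the original file): with
+#   wildcard_option_dict = {"animal": ["dog", "cat"], "object": ["chair", "sofa"]}
+# a prompt such as "A __animal__ sitting on a __object__" becomes e.g.
+# "A dog sitting on a chair"; each __name__ token is replaced by a random entry from
+# the matching list (or from a `name.txt` file passed via `wildcard_files`).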
+
+
+@dataclass
+class WildcardStableDiffusionOutput(StableDiffusionPipelineOutput):
+ prompts: List[str]
+
+
+class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Example Usage:
+ pipe = WildcardStableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+
+ torch_dtype=torch.float16,
+ )
+ prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
+ out = pipe(
+ prompt,
+ wildcard_option_dict={
+ "clothing":["hat", "shirt", "scarf", "beret"]
+ },
+ wildcard_files=["object.txt", "animal.txt"],
+ num_prompt_samples=1
+ )
+
+
+ Pipeline for text-to-image generation with wildcards using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ wildcard_option_dict: Dict[str, List[str]] = {},
+ wildcard_files: List[str] = [],
+ num_prompt_samples: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+            wildcard_option_dict (`Dict[str, List[str]]`):
+                Dict mapping a wildcard name to a list of possible replacements. For example, for the prompt
+                "A __animal__ sitting on a chair", `{"animal": ["dog", "cat", "fox"]}` provides possible values for
+                the `__animal__` wildcard.
+            wildcard_files (`List[str]`):
+                List of paths to `.txt` files that provide wildcard replacements, one per line. For example, for the
+                prompt "A __animal__ sitting on a chair", `["animal.txt"]` can be provided.
+            num_prompt_samples (`int`, *optional*, defaults to 1):
+                Number of times to sample the wildcards for each provided prompt.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
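+        # Hypothetical usage sketch (illustrative only; values are placeholders): every `__wildcard__`
+        # token in a prompt is replaced with a randomly chosen option before encoding, e.g.
+        #     pipe(
+        #         "A __animal__ sitting on a chair",
+        #         wildcard_option_dict={"animal": ["dog", "cat", "fox"]},
+        #         num_prompt_samples=2,
+        #     )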
+
+ if isinstance(prompt, str):
+ prompt = [
+ replace_prompt_with_wildcards(prompt, wildcard_option_dict, wildcard_files)
+ for i in range(num_prompt_samples)
+ ]
+ batch_size = len(prompt)
+ elif isinstance(prompt, list):
+ prompt_list = []
+ for p in prompt:
+ for i in range(num_prompt_samples):
+ prompt_list.append(replace_prompt_with_wildcards(p, wildcard_option_dict, wildcard_files))
+ prompt = prompt_list
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
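+        # 0.18215 is the latent scaling factor of the Stable Diffusion v1 VAE; undo it before decoding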
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return WildcardStableDiffusionOutput(images=image, nsfw_content_detected=has_nsfw_concept, prompts=prompt)
diff --git a/diffusers/examples/conftest.py b/diffusers/examples/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..57d77d18a915095f27965da79463c4698749042e
--- /dev/null
+++ b/diffusers/examples/conftest.py
@@ -0,0 +1,45 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# tests directory-specific settings - this file is run automatically
+# by pytest before any tests are run
+
+import sys
+import warnings
+from os.path import abspath, dirname, join
+
+
+# allow having multiple repository checkouts and not needing to remember to rerun
+# 'pip install -e .[dev]' when switching between checkouts and running tests.
+git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
+sys.path.insert(1, git_repo_path)
+
+
+# silence FutureWarning warnings in tests since often we can't act on them until
+# they become normal warnings - i.e. the tests still need to test the current functionality
+warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+def pytest_addoption(parser):
+ from diffusers.utils.testing_utils import pytest_addoption_shared
+
+ pytest_addoption_shared(parser)
+
+
+def pytest_terminal_summary(terminalreporter):
+ from diffusers.utils.testing_utils import pytest_terminal_summary_main
+
+ make_reports = terminalreporter.config.getoption("--make-reports")
+ if make_reports:
+ pytest_terminal_summary_main(terminalreporter, id=make_reports)
diff --git a/diffusers/examples/consistency_distillation/README.md b/diffusers/examples/consistency_distillation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f5cb72fa8682040a61e7372a02ddfae934222b08
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/README.md
@@ -0,0 +1,112 @@
+# Latent Consistency Distillation Example:
+
+[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) propose a method to distill a latent diffusion model, enabling swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with few timesteps.
+
+## Full model distillation
+
+### Running locally with PyTorch
+
+#### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or, if your environment doesn't support an interactive shell (e.g., a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting the torch compile mode to True can give dramatic speedups.
+
+
+#### Example
+
+The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="path/to/saved/model"
+
+accelerate launch train_lcm_distill_sd_wds.py \
+ --pretrained_teacher_model=$MODEL_NAME \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=512 \
+ --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub
+```
+
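+Once distillation has finished, the distilled model can be used for few-step inference with the LCM scheduler. Below is a minimal sketch; the checkpoint path and the `unet` subfolder layout are assumptions based on the command above, so adjust them to match how your run saved the model:
+
+```python
+import torch
+from diffusers import DiffusionPipeline, LCMScheduler, UNet2DConditionModel
+
+# Assumption: the distilled U-Net was saved under $OUTPUT_DIR/unet
+unet = UNet2DConditionModel.from_pretrained("path/to/saved/model", subfolder="unet", torch_dtype=torch.float16)
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
+).to("cuda")
+# Swap in the LCM scheduler so that only a few sampling steps are needed
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+image = pipe(
+    "a photo of an astronaut riding a horse on mars",
+    num_inference_steps=4,
+    guidance_scale=1.0,
+).images[0]
+image.save("lcm_sample.png")
+```
+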
+## LCM-LoRA
+
+Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any Stable Diffusion model.
+
+### Example
+
+The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="path/to/saved/model"
+
+accelerate launch train_lcm_distill_lora_sd_wds.py \
+ --pretrained_teacher_model=$MODEL_NAME \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=512 \
+ --lora_rank=64 \
+ --learning_rate=1e-4 --loss_type="huber" --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+  --push_to_hub
+```
\ No newline at end of file
diff --git a/diffusers/examples/consistency_distillation/README_sdxl.md b/diffusers/examples/consistency_distillation/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..edb4bd57f291a597fa01dc2ecbe60c4c69b7f9cd
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/README_sdxl.md
@@ -0,0 +1,148 @@
+# Latent Consistency Distillation Example:
+
+[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) propose a method to distill a latent diffusion model, enabling swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with few timesteps.
+
+## Full model distillation
+
+### Running locally with PyTorch
+
+#### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or, if your environment doesn't support an interactive shell (e.g., a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting the torch compile mode to True can give dramatic speedups.
+
+
+#### Example
+
+The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space according to the dataset you use.
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export OUTPUT_DIR="path/to/saved/model"
+
+accelerate launch train_lcm_distill_sdxl_wds.py \
+ --pretrained_teacher_model=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=1024 \
+ --learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --ema_decay=0.95 --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+  --push_to_hub
+```
+
+## LCM-LoRA
+
+Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
+
+### Example
+
+The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export OUTPUT_DIR="path/to/saved/model"
+
+accelerate launch train_lcm_distill_lora_sdxl_wds.py \
+  --pretrained_teacher_model=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=1024 \
+ --lora_rank=64 \
+ --learning_rate=1e-4 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+  --push_to_hub
+```
+
+We provide another version for LCM LoRA SDXL that follows best practices of `peft` and leverages the `datasets` library for quick experimentation. Unlike `train_lcm_distill_lora_sdxl_wds.py`, this script doesn't load two UNets, which reduces the memory requirements quite a bit.
+
+Below is an example training command that trains an LCM LoRA on the [Narutos dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions):
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
+
+accelerate launch train_lcm_distill_lora_sdxl.py \
+ --pretrained_teacher_model=${MODEL_NAME} \
+ --pretrained_vae_model_name_or_path=${VAE_PATH} \
+ --output_dir="narutos-lora-lcm-sdxl" \
+ --mixed_precision="fp16" \
+ --dataset_name=$DATASET_NAME \
+ --resolution=1024 \
+ --train_batch_size=24 \
+ --gradient_accumulation_steps=1 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --lora_rank=64 \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=3000 \
+ --checkpointing_steps=500 \
+ --validation_steps=50 \
+ --seed="0" \
+ --push_to_hub
+```
+
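+After training, the LCM-LoRA can be plugged into the SDXL base pipeline together with the LCM scheduler for few-step inference. A minimal sketch, assuming the LoRA weights ended up in the `narutos-lora-lcm-sdxl` output directory used above (or in the corresponding Hub repository):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, LCMScheduler
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+# Replace the default scheduler with the LCM scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+# Assumption: LoRA weights live in the local output directory (or pass the Hub repo id instead)
+pipe.load_lora_weights("narutos-lora-lcm-sdxl")
+
+image = pipe(
+    "a naruto character eating ramen, detailed, 8k",
+    num_inference_steps=4,
+    guidance_scale=1.0,
+).images[0]
+image.save("lcm_lora_sdxl_sample.png")
+```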
diff --git a/diffusers/examples/consistency_distillation/requirements.txt b/diffusers/examples/consistency_distillation/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..09fb84270a8a897c2082748fff25220b94d4532e
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+webdataset
\ No newline at end of file
diff --git a/diffusers/examples/consistency_distillation/test_lcm_lora.py b/diffusers/examples/consistency_distillation/test_lcm_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca66909a2061dd5918bb657ea8cd1d26c099a3a
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/test_lcm_lora.py
@@ -0,0 +1,112 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextToImageLCM(ExamplesTestsAccelerate):
+ def test_text_to_image_lcm_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+ --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --lora_rank 4
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ def test_text_to_image_lcm_lora_sdxl_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+ --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --lora_rank 4
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --checkpointing_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
+ )
+
+ test_args = f"""
+ examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+ --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --lora_rank 4
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 9
+ --checkpointing_steps 2
+ --resume_from_checkpoint latest
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
diff --git a/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..611026675daf619b4cd91250740d0f76426ac640
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -0,0 +1,1467 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
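+# Convert a PEFT LoRA state dict to the Kohya-ss naming scheme (`lora_down`/`lora_up` weights plus a
+# per-module `alpha` entry) so that the distilled LoRA can be loaded with `pipeline.load_lora_weights`.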
+def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
+ kohya_ss_state_dict = {}
+ for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items():
+ kohya_key = peft_key.replace("base_model.model", prefix)
+ kohya_key = kohya_key.replace("lora_A", "lora_down")
+ kohya_key = kohya_key.replace("lora_B", "lora_up")
+ kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
+ kohya_ss_state_dict[kohya_key] = weight.to(dtype)
+
+ # Set alpha parameter
+ if "lora_down" in kohya_key:
+ alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
+
+ return kohya_ss_state_dict
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+    :param keys: function that splits the key into key and extension (base_plus_ext)
+    :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+ # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class SDText2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 512,
+ interpolation_type: str = "bilinear",
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=interpolation_mode)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
+ wds.map(filter_keys({"image", "text"})),
+ wds.map(transform),
+ wds.to_tuple("image", "text"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ safety_checker=None,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype)
+ pipeline.load_lora_weights(lora_state_dict)
+ pipeline.fuse_lora()
+
+ pipeline = pipeline.to(accelerator.device, dtype=weight_dtype)
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with autocast_ctx:
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ guidance_scale=1.0,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+# From LatentConsistencyModel.get_guidance_scale_embedding
+def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+        w (`torch.Tensor`):
+            guidance scale values at which to generate embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
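+    # e.g. guidance_scale_embedding(torch.tensor([1.5, 7.5]), embedding_dim=256) returns a tensor of shape (2, 256)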
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
+ scaled_timestep = timestep_scaling * timestep
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
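+    # At timestep 0 this yields c_skip == 1 and c_out == 0, i.e. the consistency boundary condition f(x, 0) == x.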
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
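+# Given the forward process x_t = alpha_t * x_0 + sigma_t * eps, recover the predicted x_0 from the model
+# output under each supported prediction parameterization.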
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "sample":
+ pred_x_0 = model_output
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas * sample - sigmas * model_output
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_x_0
+
+
+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_epsilon = model_output
+ elif prediction_type == "sample":
+ pred_epsilon = (sample - alphas * model_output) / sigmas
+ elif prediction_type == "v_prediction":
+ pred_epsilon = alphas * model_output + sigmas * sample
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_epsilon
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
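+        # Deterministic DDIM update (eta = 0): x_prev = sqrt(alpha_prev) * pred_x0 + sqrt(1 - alpha_prev) * pred_noise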
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+            "The path(s) or URL(s) of the webdataset training shards to train on, e.g. a brace-expandable pattern"
+            " or a `pipe:` command, as accepted by the webdataset library."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--interpolation_type",
+ type=str,
+ default="bilinear",
+ help=(
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=5.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=64,
+ help="The rank of the LoRA projection matrix.",
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=64,
+ help=(
+ "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+ " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+ ),
+ )
+ parser.add_argument(
+ "--lora_dropout",
+ type=float,
+ default=0.0,
+ help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+ )
+ parser.add_argument(
+ "--lora_target_modules",
+ type=str,
+ default=None,
+ help=(
+ "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+ " be used. By default, LoRA will be applied to all conv and linear layers."
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=32,
+ required=False,
+ help=(
+ "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+ " Encoding or decoding the whole batch at once may run into OOM issues."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_scaling_factor",
+ type=float,
+ default=10.0,
+ help=(
+ "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+ " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+ " suffice."
+ ),
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
+
+ return prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+        split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+ # Initialize the DDIM ODE solver for distillation.
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SD 1.X/2.X checkpoint.
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SD 1.X/2.X checkpoint.
+ # import correct text encoder classes
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+
+ # 4. Load VAE from SD 1.X/2.X checkpoint
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_teacher_model,
+ subfolder="vae",
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load teacher U-Net from SD 1.X/2.X checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoder, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create online student U-Net.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+ unet.train()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
+ if args.lora_target_modules is not None:
+ lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+ else:
+ lora_target_modules = [
+ "to_q",
+ "to_k",
+ "to_v",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+ ]
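+    # PEFT scales the LoRA update by lora_alpha / lora_rank, so with the defaults (64 / 64) the
+    # delta_W update is applied unscaled, matching the `--lora_alpha` help text above.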
+ lora_config = LoraConfig(
+ r=args.lora_rank,
+ target_modules=lora_target_modules,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ )
+ unet = get_peft_model(unet, lora_config)
+
+ # 9. Handle mixed precision and device placement
+    # For mixed precision training we cast all non-trainable weights to half-precision,
+    # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ # Move the ODE solver to accelerator.device.
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ unet_ = accelerator.unwrap_model(unet)
+ lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default")
+ StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)
+ # save weights in peft format to be able to load them back
+ unet_.save_pretrained(output_dir)
+
+ for _, model in enumerate(models):
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ # load the LoRA into the model
+ unet_ = accelerator.unwrap_model(unet)
+ unet_.load_adapter(input_dir, "default", is_trainable=True)
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ models.pop()
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ # target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 13. Dataset creation and data processing
+    # Here, we compute the prompt embeddings used to condition both the student and teacher U-Nets.
+ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
+ prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
+ return {"prompt_embeds": prompt_embeds}
+
+ dataset = SDText2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ interpolation_type=args.interpolation_type,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ # 14. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # 15. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ uncond_input_ids = tokenizer(
+ [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=77
+ ).input_ids.to(accelerator.device)
+ uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
+
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ # 16. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # 1. Load and process the image and text conditioning
+ image, text = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text)
+
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+
+ # encode pixel values with batch size of at most args.vae_encode_batch_size
+ latents = []
+ for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+ latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ latents = latents.to(weight_dtype)
+ bsz = latents.shape[0]
+
+ # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+ # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
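+                # e.g. with the usual num_train_timesteps=1000 and num_ddim_timesteps=50, topk=20, so
+                # start_timesteps lie in {19, 39, ..., 999} and timesteps are one DDIM step earlier
+                # (clamped at 0).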
+
+ # 3. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(
+ start_timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(
+ timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+ # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noise = torch.randn_like(latents)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 5. Sample a random guidance scale w from U[w_min, w_max]
+ # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
+
+ # 6. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = get_predicted_original_sample(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
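+                # This is the consistency-model parameterization f_theta(z, t) = c_skip(t) * z + c_out(t) * x0_pred,
+                # which approaches the identity map as t -> 0.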
+
+ # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+ # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+ # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+ # solver timestep.
+ with torch.no_grad():
+ with autocast_ctx:
+ # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ ).sample
+ cond_pred_x0 = get_predicted_original_sample(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ cond_pred_noise = get_predicted_noise(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ ).sample
+ uncond_pred_x0 = get_predicted_original_sample(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ uncond_pred_noise = get_predicted_noise(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+ # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
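+                        # Algebraically, cond + w * (cond - uncond) equals (1 + w) * cond - w * uncond.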
+ # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+ # augmented PF-ODE trajectory (solving backward in time)
+ # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+ # Note that we do not use a separate target network for LCM-LoRA distillation.
+ with torch.no_grad():
+ with autocast_ctx:
+ target_noise_pred = unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ ).sample
+ pred_x_0 = get_predicted_original_sample(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
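+                    # The target is the student's own (gradient-free) consistency estimate at the earlier
+                    # timestep t_n, so the loss below enforces self-consistency along the solver trajectory.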
+
+ # 10. Calculate loss
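+                # The "huber" branch uses the pseudo-Huber form sqrt(d^2 + c^2) - c, which behaves like
+                # d^2 / (2c) for small errors and like |d| for large ones, making it less outlier-sensitive than L2.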
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 11. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step)
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(args.output_dir)
+ lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
+ StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..8090926974c44463d1475c94ac8e0bb055d0d756
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -0,0 +1,1467 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The LCM team and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
+from diffusers.utils import (
+ check_min_version,
+ convert_state_dict_to_diffusers,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
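+        # e.g. timesteps=1000 and ddim_timesteps=50 give step_ratio=20 and a schedule of [19, 39, ..., 999].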
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
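+        # Deterministic DDIM (eta = 0) update to the previous solver timestep:
+        # x_prev = sqrt(alpha_prev) * pred_x0 + sqrt(1 - alpha_prev) * pred_noise.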
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_final_validation=False):
+ logger.info("Running validation... ")
+
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ ).to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ to_load = None
+ if not is_final_validation:
+ if unet is None:
+ raise ValueError("Must provide a `unet` when doing intermediate validation.")
+ unet = accelerator.unwrap_model(unet)
+ state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+ to_load = state_dict
+ else:
+ to_load = args.output_dir
+
+ pipeline.load_lora_weights(to_load)
+ pipeline.fuse_lora()
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "cute sundar pichai character",
+ "robotic cat with wings",
+ "a photo of yoda",
+ "a cute creature with blue eyes",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ with autocast_ctx:
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ guidance_scale=0.0,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+ logger_name = "test" if is_final_validation else "validation"
+ tracker.log({logger_name: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
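+    # As timestep -> 0, c_skip -> 1 and c_out -> 0, so the consistency function reduces to the
+    # identity at the boundary, which is the boundary condition f(x, 0) = x that these scalings enforce.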
+ scaled_timestep = timestep_scaling * timestep
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
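+    # Invert the forward process sample = alpha_t * x_0 + sigma_t * noise to recover x_0 for the
+    # given prediction type (for `sample` the model already predicts x_0 directly).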
+ if prediction_type == "epsilon":
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "sample":
+ pred_x_0 = model_output
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas * sample - sigmas * model_output
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_x_0
+
+
+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
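+    # Recover eps from the model output: for `sample` predictions eps = (sample - alpha_t * x_0) / sigma_t,
+    # while for `v_prediction` (v = alpha_t * eps - sigma_t * x_0) this gives eps = alpha_t * v + sigma_t * sample.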
+ if prediction_type == "epsilon":
+ pred_epsilon = model_output
+ elif prediction_type == "sample":
+ pred_epsilon = (sample - alphas * model_output) / sigmas
+ elif prediction_type == "v_prediction":
+ pred_epsilon = alphas * model_output + sigmas * sample
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_epsilon
+
+
+def extract_into_tensor(a, t, x_shape):
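+    # Gather the per-timestep entries a[t] and reshape them to (b, 1, ..., 1) so they broadcast over x.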
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--interpolation_type",
+ type=str,
+ default="bilinear",
+ help=(
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--encode_batch_size",
+ type=int,
+ default=8,
+ help="Batch size to use for VAE encoding of the images for efficient processing.",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=3.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=64,
+ help="The rank of the LoRA projection matrix.",
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=64,
+ help=(
+ "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+ " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+ ),
+ )
+ parser.add_argument(
+ "--lora_dropout",
+ type=float,
+ default=0.0,
+ help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+ )
+ parser.add_argument(
+ "--lora_target_modules",
+ type=str,
+ default=None,
+ help=(
+ "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+ " be used. By default, LoRA will be applied to all conv and linear layers."
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=8,
+ required=False,
+ help=(
+ "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+ " Encoding or decoding the whole batch at once may run into OOM issues."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_scaling_factor",
+ type=float,
+ default=10.0,
+ help=(
+ "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+ " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+ " suffice."
+ ),
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+            # We are only interested in the pooled output of the final text encoder.
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
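+    # Embeddings from the two text encoders are concatenated along the feature dimension; for the
+    # standard SDXL encoders this is assumed to yield 768 + 1280 = 2048-dimensional prompt embeddings.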
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+        split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+ # Initialize the DDIM ODE solver for distillation.
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SDXL checkpoint.
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SDXL checkpoint.
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2"
+ )
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision
+ )
+
+ # 4. Load VAE from SDXL checkpoint (or more stable VAE)
+ vae_path = (
+ args.pretrained_teacher_model
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.teacher_revision,
+ )
+
+ # 6. Freeze teacher vae, text_encoders.
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+
+ # 7. Create online student U-Net.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+ unet.requires_grad_(False)
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 8. Handle mixed precision and device placement
+    # For mixed precision training we cast all non-trainable weights to half-precision,
+    # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ unet.to(accelerator.device, dtype=weight_dtype)
+ if args.pretrained_vae_model_name_or_path is None:
+ vae.to(accelerator.device, dtype=torch.float32)
+ else:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
+ if args.lora_target_modules is not None:
+ lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+ else:
+ lora_target_modules = [
+ "to_q",
+ "to_k",
+ "to_v",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+ ]
+ lora_config = LoraConfig(
+ r=args.lora_rank,
+ target_modules=lora_target_modules,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ )
+ unet.add_adapter(lora_config)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ unet_ = accelerator.unwrap_model(unet)
+                # also save the checkpoints in native `diffusers` format so that they can easily
+                # be loaded independently via `load_lora_weights()`.
+ state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
+ StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)
+
+ for _, model in enumerate(models):
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ # load the LoRA into the model
+ unet_ = accelerator.unwrap_model(unet)
+ lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
+ unet_state_dict = {
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ }
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ models.pop()
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ cast_training_params(unet_, dtype=torch.float32)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ # Materialize the trainable (LoRA) parameters as a list so they can also be reused by
+ # `accelerator.clip_grad_norm_` later (a bare `filter` iterator would already be exhausted).
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 13. Dataset creation and data processing
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ interpolation_mode = resolve_interpolation_mode(args.interpolation_type)
+ train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ x1 = image.width - x1
+ image = train_flip(image)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ examples["captions"] = list(examples[caption_column])
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ captions = [example["captions"] for example in examples]
+
+ return {
+ "pixel_values": pixel_values,
+ "captions": captions,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # 14. Embeddings for the UNet.
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
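+ # The SDXL UNet is additionally conditioned on (original_size, crop_top_left, target_size);
+ # these six integers are packed into `add_time_ids` and passed through `added_cond_kwargs`.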
+ def compute_embeddings(prompt_batch, original_sizes, crop_coords, text_encoders, tokenizers, is_train=True):
+ def compute_time_ids(original_size, crops_coords_top_left):
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(prompt_batch, text_encoders, tokenizers, is_train)
+ add_text_embeds = pooled_prompt_embeds
+
+ add_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(original_sizes, crop_coords)])
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers)
+
+ # 15. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(unet, dtype=torch.float32)
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ )
+
+ # 16. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # 17. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ unet.train()
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # 1. Load and process the image and text conditioning
+ pixel_values, text, orig_size, crop_coords = (
+ batch["pixel_values"],
+ batch["captions"],
+ batch["original_sizes"],
+ batch["crop_top_lefts"],
+ )
+
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+
+ # encode pixel values with batch size of at most args.vae_encode_batch_size
+ pixel_values = pixel_values.to(dtype=vae.dtype)
+ latents = []
+ for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+ latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+ # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
+ bsz = latents.shape[0]
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
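+ # E.g. with 1000 training timesteps and num_ddim_timesteps=50, topk == 20, start_timesteps are
+ # drawn from {19, 39, ..., 999}, and `timesteps` is the preceding solver step
+ # (start_timesteps - 20, clamped to 0).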
+
+ # 3. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(
+ start_timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(
+ timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+ # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noise = torch.randn_like(latents)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 5. Sample a random guidance scale w from U[w_min, w_max]
+ # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
+
+ # 6. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs=encoded_text,
+ ).sample
+ pred_x_0 = get_predicted_original_sample(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
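+ # `model_pred` is the online consistency function f_theta(z_{t_{n+k}}, w, c, t_{n+k}),
+ # parameterized as c_skip * z + c_out * x_0_hat.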
+
+ # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+ # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+ # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+ # solver timestep.
+
+ # With the adapters disabled, the `unet` is the regular teacher model.
+ accelerator.unwrap_model(unet).disable_adapters()
+ with torch.no_grad():
+ # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
+ cond_teacher_output = unet(
+ noisy_model_input,
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
+ ).sample
+ cond_pred_x0 = get_predicted_original_sample(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ cond_pred_noise = get_predicted_noise(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
+ uncond_prompt_embeds = torch.zeros_like(prompt_embeds)
+ uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"])
+ uncond_added_conditions = copy.deepcopy(encoded_text)
+ uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
+ uncond_teacher_output = unet(
+ noisy_model_input,
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
+ ).sample
+ uncond_pred_x0 = get_predicted_original_sample(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ uncond_pred_noise = get_predicted_noise(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+ # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
+ # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+ # augmented PF-ODE trajectory (solving backward in time)
+ # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
+
+ # re-enable unet adapters to turn the `unet` into a student unet.
+ accelerator.unwrap_model(unet).enable_adapters()
+
+ # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+ # Note that we do not use a separate target network for LCM-LoRA distillation.
+ with torch.no_grad():
+ target_noise_pred = unet(
+ x_prev,
+ timesteps,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
+ ).sample
+ pred_x_0 = get_predicted_original_sample(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
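+ # `target` is the consistency function evaluated at the previous solver step, computed without
+ # gradients; the loss below pulls the online prediction `model_pred` towards it.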
+
+ # 10. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
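+ # Pseudo-Huber loss: sqrt(d^2 + c^2) - c, roughly quadratic for |d| << c and linear for |d| >> c.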
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 11. Backpropagate on the online student model (`unet`) (only LoRA)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(
+ vae, args, accelerator, weight_dtype, global_step, unet=unet, is_final_validation=False
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+ StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ del unet
+ torch.cuda.empty_cache()
+
+ # Final inference.
+ if args.validation_steps is not None:
+ log_validation(vae, args, accelerator, weight_dtype, step=global_step, unet=None, is_final_validation=True)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa7e7f1febee8a57c0fc6bb767deb72366df0281
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -0,0 +1,1530 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+# Adjust for your dataset
+WDS_JSON_WIDTH = "width" # original_width for LAION
+WDS_JSON_HEIGHT = "height" # original_height for LAION
+MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
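+ # Remap the PEFT LoRA state dict to the kohya_ss naming scheme (lora_down/lora_up plus a per-layer
+ # `alpha` entry) so it can be loaded with `pipeline.load_lora_weights()` during validation.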
+ kohya_ss_state_dict = {}
+ for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items():
+ kohya_key = peft_key.replace("base_model.model", prefix)
+ kohya_key = kohya_key.replace("lora_A", "lora_down")
+ kohya_key = kohya_key.replace("lora_B", "lora_up")
+ kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
+ kohya_ss_state_dict[kohya_key] = weight.to(dtype)
+
+ # Set alpha parameter
+ if "lora_down" in kohya_key:
+ alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
+
+ return kohya_ss_state_dict
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with the same prefix as the next
+ # begins; rare, but it can happen since prefixes aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
+ WDS_JSON_HEIGHT, 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class SDXLText2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 1024,
+ interpolation_type: str = "bilinear",
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ use_fix_crop_and_size: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def get_orig_size(json):
+ if use_fix_crop_and_size:
+ return (resolution, resolution)
+ else:
+ return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
+
+ interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=interpolation_mode)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0)
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp", text="text;txt;caption", orig_size="json", handler=wds.warn_and_continue
+ ),
+ wds.map(filter_keys({"image", "text", "orig_size"})),
+ wds.map_dict(orig_size=get_orig_size),
+ wds.map(transform),
+ wds.to_tuple("image", "text", "orig_size", "crop_coords"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype)
+ pipeline.load_lora_weights(lora_state_dict)
+ pipeline.fuse_lora()
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with autocast_ctx:
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ guidance_scale=0.0,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
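+ # As timestep -> 0, c_skip -> 1 and c_out -> 0, so c_skip * x + c_out * F(x, t) satisfies the
+ # consistency boundary condition f(x, 0) = x; `timestep_scaling` controls how quickly this happens.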
+ scaled_timestep = timestep_scaling * timestep
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "sample":
+ pred_x_0 = model_output
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas * sample - sigmas * model_output
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_x_0
+
+
+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_epsilon = model_output
+ elif prediction_type == "sample":
+ pred_epsilon = (sample - alphas * model_output) / sigmas
+ elif prediction_type == "v_prediction":
+ pred_epsilon = alphas * model_output + sigmas * sample
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_epsilon
+
+
+def extract_into_tensor(a, t, x_shape):
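+ # Gather a[t] for each batch element and reshape to (b, 1, ..., 1) so it broadcasts against x_shape.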
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
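+ # Deterministic DDIM update (eta = 0): x_prev = sqrt(alpha_prev) * x0_hat + sqrt(1 - alpha_prev) * eps_hat.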
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "Path or URL of the WebDataset training shards (a set of .tar files) to train on, e.g. a local"
+ " path or a remote shard URL pattern understood by webdataset."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--interpolation_type",
+ type=str,
+ default="bilinear",
+ help=(
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+ ),
+ )
+ parser.add_argument(
+ "--use_fix_crop_and_size",
+ action="store_true",
+ help="Whether or not to use the fixed crop and size for the teacher model.",
+ default=False,
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=3.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=64,
+ help="The rank of the LoRA projection matrix.",
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=64,
+ help=(
+ "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+ " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+ ),
+ )
+ parser.add_argument(
+ "--lora_dropout",
+ type=float,
+ default=0.0,
+ help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+ )
+ parser.add_argument(
+ "--lora_target_modules",
+ type=str,
+ default=None,
+ help=(
+ "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+ " be used. By default, LoRA will be applied to all conv and linear layers."
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=8,
+ required=False,
+ help=(
+ "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+ " Encoding or decoding the whole batch at once may run into OOM issues."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_scaling_factor",
+ type=float,
+ default=10.0,
+ help=(
+ "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+ " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+ " suffice."
+ ),
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system"
+ " or the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For more information, see"
+ " https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
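+ # SDXL uses two text encoders: the penultimate hidden states of both are concatenated along the
+ # feature dimension, while only the pooled output of the second encoder is kept.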
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final (second) text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
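+ # Under this parameterization the forward (noising) process is z_t = alpha_t * x_0 + sigma_t * eps.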
+ # Initialize the DDIM ODE solver for distillation.
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SD-XL checkpoint.
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SD-XL checkpoint.
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2"
+ )
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision
+ )
+
+ # 4. Load VAE from SD-XL checkpoint (or more stable VAE)
+ vae_path = (
+ args.pretrained_teacher_model
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load teacher U-Net from SD-XL checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoders, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create online student U-Net.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+ unet.train()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
+ if args.lora_target_modules is not None:
+ lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+ else:
+ lora_target_modules = [
+ "to_q",
+ "to_k",
+ "to_v",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+ ]
+ lora_config = LoraConfig(
+ r=args.lora_rank,
+ target_modules=lora_target_modules,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ )
+ unet = get_peft_model(unet, lora_config)
+
+ # 9. Handle mixed precision and device placement
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ # Move the ODE solver to accelerator.device.
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 and above has better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ unet_ = accelerator.unwrap_model(unet)
+ lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default")
+ StableDiffusionXLPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)
+ # save weights in peft format to be able to load them back
+ unet_.save_pretrained(output_dir)
+
+ for _, model in enumerate(models):
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ # load the LoRA into the model
+ unet_ = accelerator.unwrap_model(unet)
+ unet_.load_adapter(input_dir, "default", is_trainable=True)
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ models.pop()
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ # target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 13. Dataset creation and data processing
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(
+ prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True
+ ):
+ target_size = (args.resolution, args.resolution)
+ original_sizes = list(map(list, zip(*original_sizes)))
+ crops_coords_top_left = list(map(list, zip(*crop_coords)))
+
+ original_sizes = torch.tensor(original_sizes, dtype=torch.long)
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ add_time_ids = list(target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
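+ # The resulting per-sample time_ids concatenate the original image size, the crop top-left
+ # coordinates, and the target size, in that order.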
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ dataset = SDXLText2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ interpolation_type=args.interpolation_type,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ use_fix_crop_and_size=args.use_fix_crop_and_size,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ # Bundle the text encoders and tokenizers so that prompt embeddings can be computed on the fly for each batch.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+
+ # 14. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # 15. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Create uncond embeds for classifier free guidance
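+ # (77 = CLIP max token length; 2048 = concatenated hidden size of the two SDXL text encoders,
+ # 768 + 1280; 1280 = pooled projection dim of the second text encoder)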
+ uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device)
+ uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device)
+
+ # 16. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
+ image, text, orig_size, crop_coords = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+ else:
+ pixel_values = image
+
+ # encode pixel values with batch size of at most args.vae_encode_batch_size
+ latents = []
+ for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+ latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+ bsz = latents.shape[0]
+
+ # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+ # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
+
+ # 3. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(
+ start_timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(
+ timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+ # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noise = torch.randn_like(latents)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 5. Sample a random guidance scale w from U[w_min, w_max]
+ # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
+
+ # 6. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = get_predicted_original_sample(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
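+ # Online consistency-model prediction f_theta(z, t) = c_skip(t) * z + c_out(t) * x0_pred,
+ # evaluated at t_{n + k}.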
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+ # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+ # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+ # solver timestep.
+ with torch.no_grad():
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_teacher_model:
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
+ ).sample
+ cond_pred_x0 = get_predicted_original_sample(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ cond_pred_noise = get_predicted_noise(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
+ uncond_added_conditions = copy.deepcopy(encoded_text)
+ uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
+ ).sample
+ uncond_pred_x0 = get_predicted_original_sample(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ uncond_pred_noise = get_predicted_noise(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+ # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
+ # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+ # augmented PF-ODE trajectory (solving backward in time)
+ # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+ # Note that we do not use a separate target network for LCM-LoRA distillation.
+ with torch.no_grad():
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ with autocast_ctx:
+ target_noise_pred = unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+ pred_x_0 = get_predicted_original_sample(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
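+ # Distillation target: the same consistency parameterization evaluated at the solver estimate
+ # x_prev and the earlier timestep t_n, computed under no_grad so only the online student is trained.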
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 10. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
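+ # The "huber" option is the pseudo-Huber loss sqrt((pred - target)^2 + c^2) - c, which behaves
+ # like an L2 loss for small residuals and an L1 loss for large ones.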
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 11. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step)
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(args.output_dir)
+ lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
+ StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..12d7db09a361a59516bbe319c67731c345091f4d
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -0,0 +1,1433 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with the same prefix as the next
+ # begins; rare, but it can happen since prefixes aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class SDText2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 512,
+ interpolation_type: str = "bilinear",
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=interpolation_mode)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
+ wds.map(filter_keys({"image", "text"})),
+ wds.map(transform),
+ wds.to_tuple("image", "text"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="target"):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ unet=unet,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({f"validation/{name}": formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+# From LatentConsistencyModel.get_guidance_scale_embedding
+def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values to generate embedding vectors for
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
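+ # e.g. for w = torch.tensor([5.0, 7.5]) and embedding_dim=256, emb has shape (2, 256); this is a
+ # sinusoidal embedding of w * 1000, analogous to a timestep embedding.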
+ return emb
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
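+ # e.g. append_dims(torch.ones(4), 4) has shape (4, 1, 1, 1), so per-sample scalars broadcast over latents.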
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
+ scaled_timestep = timestep_scaling * timestep
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
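+ # At timestep 0 these reduce to c_skip = 1 and c_out = 0, enforcing the consistency-model
+ # boundary condition f(x, 0) = x.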
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "sample":
+ pred_x_0 = model_output
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas * sample - sigmas * model_output
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_x_0
+
+
+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_epsilon = model_output
+ elif prediction_type == "sample":
+ pred_epsilon = (sample - alphas * model_output) / sigmas
+ elif prediction_type == "v_prediction":
+ pred_epsilon = alphas * model_output + sigmas * sample
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_epsilon
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
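+ # Reshape to (b, 1, ..., 1) so the gathered per-timestep values broadcast against tensors of shape x_shape.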
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
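+ # For each solver timestep, also store the alpha_cumprod of the previous solver step
+ # (alpha_cumprods[0] is prepended for the first step).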
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
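+ # One deterministic (eta = 0) DDIM step toward the previous solver timestep:
+ # x_prev = sqrt(alpha_bar_prev) * x0_pred + sqrt(1 - alpha_bar_prev) * eps_pred.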
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
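+ # In-place EMA update: targ <- rate * targ + (1 - rate) * src.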
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "Path or URL of the webdataset training shards to train on, e.g. a local path or a remote"
+ " URL pattern pointing to `.tar` shards."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--interpolation_type",
+ type=str,
+ default="bilinear",
+ help=(
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=5.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--unet_time_cond_proj_dim",
+ type=int,
+ default=256,
+ help=(
+ "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
+ " does not have `time_cond_proj_dim` set."
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=32,
+ required=False,
+ help=(
+ "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+ " Encoding or decoding the whole batch at once may run into OOM issues."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_scaling_factor",
+ type=float,
+ default=10.0,
+ help=(
+ "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+ " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+ " suffice."
+ ),
+ )
+ # ----Exponential Moving Average (EMA)----
+ parser.add_argument(
+ "--ema_decay",
+ type=float,
+ default=0.95,
+ required=False,
+ help="The exponential moving average (EMA) rate or decay factor.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
+
+ return prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+ # Initialize the DDIM ODE solver for distillation.
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SD 1.X/2.X checkpoint.
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SD 1.X/2.X checkpoint.
+ # import correct text encoder classes
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+
+ # 4. Load VAE from SD 1.X/2.X checkpoint
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_teacher_model,
+ subfolder="vae",
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load teacher U-Net from SD 1.X/2.X checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoder, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation).
+ # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
+ time_cond_proj_dim = (
+ teacher_unet.config.time_cond_proj_dim
+ if teacher_unet.config.time_cond_proj_dim is not None
+ else args.unet_time_cond_proj_dim
+ )
+ unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
+ # load teacher_unet weights into unet
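+ # (strict=False because the student config adds a guidance-scale embedding via time_cond_proj_dim,
+ # whose projection weights have no counterpart in the teacher and are left freshly initialized)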
+ unet.load_state_dict(teacher_unet.state_dict(), strict=False)
+ unet.train()
+
+ # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
+ # Initialize from (online) unet
+ target_unet = UNet2DConditionModel.from_config(unet.config)
+ target_unet.load_state_dict(unet.state_dict())
+ target_unet.train()
+ target_unet.requires_grad_(False)
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, a copy of the weights should still be kept in float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"U-Net loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 9. Handle mixed precision and device placement
+ # For mixed precision training we cast all non-trainable weights to half-precision,
+ # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is kept in float32 (unless a more numerically stable VAE is provided) to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ target_unet.to(accelerator.device)
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 and above has better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ target_unet.save_pretrained(os.path.join(output_dir, "unet_target"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target"))
+ target_unet.load_state_dict(load_model.state_dict())
+ target_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 13. Dataset creation and data processing
+ # Here, we compute the text embeddings needed for the SD U-Net to operate.
+ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
+ prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
+ return {"prompt_embeds": prompt_embeds}
+
+ dataset = SDText2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ interpolation_type=args.interpolation_type,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ # 14. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # 15. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ uncond_input_ids = tokenizer(
+ [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=77
+ ).input_ids.to(accelerator.device)
+ uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
+
+ # 16. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # 1. Load and process the image and text conditioning
+ image, text = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text)
+
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+
+ # encode pixel values with batch size of at most args.vae_encode_batch_size
+ latents = []
+ for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+ latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ latents = latents.to(weight_dtype)
+ bsz = latents.shape[0]
+
+ # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+ # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
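+ # Illustrative only, assuming 1000 training timesteps and --num_ddim_timesteps=50 (so topk = 20):
+ # a sampled start_timesteps value of 79 gives timesteps = 59, and any negative result is clamped to 0.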
+
+ # 3. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(
+ start_timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(
+ timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+ # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noise = torch.randn_like(latents)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
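+ # Assuming the usual alpha_schedule = sqrt(alphas_cumprod) and sigma_schedule = sqrt(1 - alphas_cumprod),
+ # this is noisy_model_input = alpha_schedule[start_timesteps] * latents + sigma_schedule[start_timesteps] * noise.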
+
+ # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
+ w = w.reshape(bsz, 1, 1, 1)
+ # Move to U-Net device and dtype
+ w = w.to(device=latents.device, dtype=latents.dtype)
+ w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
+
+ # 6. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = get_predicted_original_sample(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
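+ # This is the consistency-model parameterization f_theta(z, t) = c_skip(t) * z + c_out(t) * x_0_pred;
+ # the boundary scalings make c_skip(0) = 1 and c_out(0) = 0, so f_theta is the identity at t = 0.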
+
+ # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+ # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+ # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+ # solver timestep.
+ with torch.no_grad():
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ ).sample
+ cond_pred_x0 = get_predicted_original_sample(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ cond_pred_noise = get_predicted_noise(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ ).sample
+ uncond_pred_x0 = get_predicted_original_sample(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ uncond_pred_noise = get_predicted_noise(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+ # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
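+ # Algebraically, cond + w * (cond - uncond) = (1 + w) * cond - w * uncond, i.e. standard classifier-free
+ # guidance with an effective guidance scale of (1 + w).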
+ # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+ # augmented PF-ODE trajectory (solving backward in time)
+ # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+ with torch.no_grad():
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ with autocast_ctx:
+ target_noise_pred = target_unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds.float(),
+ ).sample
+ pred_x_0 = get_predicted_original_sample(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 10. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
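+ # For reference: sqrt(d**2 + c**2) - c is the pseudo-Huber loss with c = args.huber_c; it behaves like
+ # d**2 / (2 * c) for small residuals d and like |d| for large ones, so it is more robust to outliers than L2.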
+
+ # 11. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ # 12. Make EMA update to target student model parameters (`target_unet`)
+ update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(os.path.join(args.output_dir, "unet"))
+
+ target_unet = accelerator.unwrap_model(target_unet)
+ target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/diffusers/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc5e6812127e1cb1101f31ef3016ee8e8718aaa1
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -0,0 +1,1536 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+# Adjust for your dataset
+WDS_JSON_WIDTH = "width" # original_width for LAION
+WDS_JSON_HEIGHT = "height" # original_height for LAION
+MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to
+ lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+ # begins; this is rare, but can happen since prefixes aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
+ WDS_JSON_HEIGHT, 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class SDXLText2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 1024,
+ interpolation_type: str = "bilinear",
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ use_fix_crop_and_size: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def get_orig_size(json):
+ if use_fix_crop_and_size:
+ return (resolution, resolution)
+ else:
+ return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
+
+ interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=interpolation_mode)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0)
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp", text="text;txt;caption", orig_size="json", handler=wds.warn_and_continue
+ ),
+ wds.map(filter_keys({"image", "text", "orig_size"})),
+ wds.map_dict(orig_size=get_orig_size),
+ wds.map(transform),
+ wds.to_tuple("image", "text", "orig_size", "crop_coords"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
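+ # Rough bookkeeping sketch (illustrative numbers only): with num_train_examples=1_000_000,
+ # global_batch_size=256 and num_workers=4, num_worker_batches = ceil(1e6 / 1024) = 977,
+ # so num_batches = 3908 and num_samples = 1_000_448 (slightly more than the requested sample count).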
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="target"):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ unet=unet,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({f"validation/{name}": formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
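+ # For reference: at timestep 0, scaled_timestep = 0, so c_skip = 1 and c_out = 0, which enforces the
+ # consistency-model boundary condition f(x, 0) = x; sigma_data defaults to 0.5 here.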
+ scaled_timestep = timestep_scaling * timestep
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "sample":
+ pred_x_0 = model_output
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas * sample - sigmas * model_output
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_x_0
+
+
+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_epsilon = model_output
+ elif prediction_type == "sample":
+ pred_epsilon = (sample - alphas * model_output) / sigmas
+ elif prediction_type == "v_prediction":
+ pred_epsilon = alphas * model_output + sigmas * sample
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_epsilon
+
+
+def extract_into_tensor(a, t, x_shape):
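+ # Shape sketch (for reference): a of shape (num_timesteps,), t of shape (B,) and x_shape (B, C, H, W)
+ # produce an output of shape (B, 1, 1, 1) that broadcasts against the sample.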
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+# From LatentConsistencyModel.get_guidance_scale_embedding
+def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values at which to generate embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
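+ # e.g. timesteps=1000 and ddim_timesteps=50 give step_ratio=20 and ddim_timesteps=[19, 39, ..., 999]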
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
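+ # Deterministic DDIM update (eta = 0): x_prev = sqrt(alpha_cumprod_prev) * pred_x0 + sqrt(1 - alpha_cumprod_prev) * pred_noise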
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--interpolation_type",
+ type=str,
+ default="bilinear",
+ help=(
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+ ),
+ )
+ parser.add_argument(
+ "--use_fix_crop_and_size",
+ action="store_true",
+ help="Whether or not to use the fixed crop and size for the teacher model.",
+ default=False,
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=3.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--unet_time_cond_proj_dim",
+ type=int,
+ default=256,
+ help=(
+ "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
+ " does not have `time_cond_proj_dim` set."
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=8,
+ required=False,
+ help=(
+ "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+ " Encoding or decoding the whole batch at once may run into OOM issues."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_scaling_factor",
+ type=float,
+ default=10.0,
+ help=(
+ "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+ " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+ " suffice."
+ ),
+ )
+ # ----Exponential Moving Average (EMA)----
+ parser.add_argument(
+ "--ema_decay",
+ type=float,
+ default=0.95,
+ required=False,
+ help="The exponential moving average (EMA) rate or decay factor.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only ever interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
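+ # With these definitions a noised sample is x_t = alpha_schedule[t] * x_0 + sigma_schedule[t] * noise,
+ # which is the form assumed by get_predicted_original_sample and get_predicted_noise above.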
+ # Initialize the DDIM ODE solver for distillation.
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SD-XL checkpoint.
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SD-XL checkpoint.
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2"
+ )
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision
+ )
+
+ # 4. Load VAE from SD-XL checkpoint (or more stable VAE)
+ vae_path = (
+ args.pretrained_teacher_model
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load teacher U-Net from SD-XL checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoders, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
+ # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
+ time_cond_proj_dim = (
+ teacher_unet.config.time_cond_proj_dim
+ if teacher_unet.config.time_cond_proj_dim is not None
+ else args.unet_time_cond_proj_dim
+ )
+ unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
+ # load teacher_unet weights into unet
+ unet.load_state_dict(teacher_unet.state_dict(), strict=False)
+ unet.train()
+
+ # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
+ # Initialize from (online) unet
+ target_unet = UNet2DConditionModel.from_config(unet.config)
+ target_unet.load_state_dict(unet.state_dict())
+ target_unet.train()
+ target_unet.requires_grad_(False)
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 9. Handle mixed precision and device placement
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is kept in float32 to avoid NaN losses, unless a more numerically stable VAE was provided (cast to weight_dtype below).
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ target_unet.to(accelerator.device)
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ # Move the ODE solver to accelerator.device.
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 and above has better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ target_unet.save_pretrained(os.path.join(output_dir, "unet_target"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target"))
+ target_unet.load_state_dict(load_model.state_dict())
+ target_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 13. Dataset creation and data processing
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(
+ prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True
+ ):
+ target_size = (args.resolution, args.resolution)
+ original_sizes = list(map(list, zip(*original_sizes)))
+ crops_coords_top_left = list(map(list, zip(*crop_coords)))
+
+ original_sizes = torch.tensor(original_sizes, dtype=torch.long)
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ add_time_ids = list(target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ dataset = SDXLText2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ interpolation_type=args.interpolation_type,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ use_fix_crop_and_size=args.use_fix_crop_and_size,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+
+ # 14. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # 15. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Create uncond embeds for classifier free guidance
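+    # For SDXL, the per-token prompt embeds concatenate both text encoders' hidden states
+    # (768 + 1280 = 2048) over a 77-token context, and the pooled embeds (1280) come from the
+    # second text encoder, hence the shapes of the zero tensors below.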
+ uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device)
+ uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device)
+
+ # 16. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
+ image, text, orig_size, crop_coords = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+ else:
+ pixel_values = image
+
+ # encode pixel values with batch size of at most args.vae_encode_batch_size
+ latents = []
+ for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+ latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+ bsz = latents.shape[0]
+
+ # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+ # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
+
+ # 3. Get boundary scalings for start_timesteps and (end) timesteps.
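+                # These scalings enforce the consistency-model boundary condition: c_skip -> 1 and
+                # c_out -> 0 as the timestep approaches 0, so the model reduces to the identity at the boundary.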
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(
+ start_timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(
+ timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+ # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noise = torch.randn_like(latents)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
+ w = w.reshape(bsz, 1, 1, 1)
+ # Move to U-Net device and dtype
+ w = w.to(device=latents.device, dtype=latents.dtype)
+ w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
+
+ # 6. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = get_predicted_original_sample(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+ # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+ # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+ # solver timestep.
+ with torch.no_grad():
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
+ ).sample
+ cond_pred_x0 = get_predicted_original_sample(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ cond_pred_noise = get_predicted_noise(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
+ uncond_added_conditions = copy.deepcopy(encoded_text)
+ uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
+ ).sample
+ uncond_pred_x0 = get_predicted_original_sample(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ uncond_pred_noise = get_predicted_noise(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+ # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
+ # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+ # augmented PF-ODE trajectory (solving backward in time)
+ # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+ with torch.no_grad():
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ with autocast_ctx:
+ target_noise_pred = target_unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+ pred_x_0 = get_predicted_original_sample(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 10. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
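+                    # Pseudo-Huber loss: sqrt(err^2 + c^2) - c behaves like an L2 loss for small errors
+                    # and like an L1 loss for large ones, making distillation more robust to outliers.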
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 11. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ # 12. Make EMA update to target student model parameters (`target_unet`)
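+                    # i.e. target_param <- ema_decay * target_param + (1 - ema_decay) * online_param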
+ update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(os.path.join(args.output_dir, "unet"))
+
+ target_unet = accelerator.unwrap_model(target_unet)
+ target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/controlnet/README.md b/diffusers/examples/controlnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0555857b773834a14462d3a09e3e17866b713663
--- /dev/null
+++ b/diffusers/examples/controlnet/README.md
@@ -0,0 +1,465 @@
+# ControlNet training example
+
+[Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) by Lvmin Zhang and Maneesh Agrawala.
+
+This example is based on the [training example in the original ControlNet repository](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md). It trains a ControlNet to fill circles using a [small synthetic dataset](https://huggingface.co/datasets/fusing/fill50k).
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+## Circle filling dataset
+
+The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
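+
+If you want to take a quick look at the data before training, you can load it directly with `datasets`. This is only a minimal sketch; the `image`, `conditioning_image`, and `text` column names are the training script's defaults, and depending on your `datasets` version you may need to pass `trust_remote_code=True`:
+
+```py
+from datasets import load_dataset
+
+# Download the fill50k dataset from the Hub (~50k image / conditioning-image / caption triplets)
+dataset = load_dataset("fusing/fill50k", split="train")
+print(dataset)
+
+sample = dataset[0]
+sample["image"].save("sample_target.png")                      # target image
+sample["conditioning_image"].save("sample_conditioning.png")   # circle conditioning image
+print(sample["text"])                                          # caption describing the colors
+```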
+
+Our training examples use [Stable Diffusion 1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) as the original set of ControlNet models were trained from it. However, ControlNet can be trained to augment any Stable Diffusion compatible model (such as [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) or [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)).
+
+## Training
+
+Our training examples use two test conditioning images. They can be downloaded by running
+
+```sh
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+
+```bash
+export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=4
+```
+
+This default configuration requires ~38GB VRAM.
+
+By default, the training script logs outputs to TensorBoard. Pass `--report_to wandb` to use Weights and Biases.
+
+Gradient accumulation with a smaller batch size can be used to reduce training requirements to ~20 GB VRAM.
+
+```bash
+export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4
+```
+
+## Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch --mixed_precision="fp16" --multi_gpu train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=4 \
+ --mixed_precision="fp16" \
+ --tracker_project_name="controlnet-demo" \
+ --report_to=wandb
+```
+
+## Example results
+
+#### After 300 steps with batch size 8
+
+| | |
+|-------------------|:-------------------------:|
+| | red circle with blue background |
+| | cyan circle with brown floral background |
+
+
+#### After 6000 steps with batch size 8:
+
+| | |
+|-------------------|:-------------------------:|
+| | red circle with blue background |
+| | cyan circle with brown floral background |
+
+## Training on a 16 GB GPU
+
+Optimizations:
+- Gradient checkpointing
+- bitsandbytes' 8-bit optimizer
+
+[bitsandbytes install instructions](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+```bash
+export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam
+```
+
+## Training on a 12 GB GPU
+
+Optimizations:
+- Gradient checkpointing
+- bitsandbytes' 8-bit optimizer
+- xformers
+- set grads to none
+
+```bash
+export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none
+```
+
+When using `enable_xformers_memory_efficient_attention`, please make sure to install `xformers` by `pip install xformers`.
+
+## Training on an 8 GB GPU
+
+We have not exhaustively tested DeepSpeed support for ControlNet. While the configuration does
+save memory, we have not confirmed that it trains successfully. You will very likely
+have to make changes to the config to get a successful training run.
+
+Optimizations:
+- Gradient checkpointing
+- xformers
+- set grads to none
+- DeepSpeed stage 2 with parameter and optimizer offloading
+- fp16 mixed precision
+
+[DeepSpeed](https://www.deepspeed.ai/) can offload tensors from VRAM to either
+CPU or NVMe. This requires significantly more RAM (about 25 GB).
+
+Use `accelerate config` to enable DeepSpeed stage 2.
+
+The relevant parts of the resulting accelerate config file are
+
+```yaml
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ gradient_accumulation_steps: 4
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+```
+
+See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
+
+Changing the default Adam optimizer to DeepSpeed's Adam
+`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup, but
+it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer
+does not seem to be compatible with DeepSpeed at the moment.
+
+```bash
+export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+ --mixed_precision fp16
+```
+
+## Performing inference with the trained ControlNet
+
+The trained ControlNet can be used with the standard ControlNet pipeline by simply swapping in the newly trained weights.
+Set `base_model_path` and `controlnet_path` to the values `--pretrained_model_name_or_path` and
+`--output_dir` were respectively set to in the training script.
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from diffusers.utils import load_image
+import torch
+
+base_model_path = "path to model"
+controlnet_path = "path to controlnet"
+
+controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16
+)
+
+# speed up diffusion process with faster scheduler and memory optimization
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+# remove following line if xformers is not installed or when using Torch 2.0.
+pipe.enable_xformers_memory_efficient_attention()
+# memory optimization.
+pipe.enable_model_cpu_offload()
+
+control_image = load_image("./conditioning_image_1.png")
+prompt = "pale golden rod circle with old lace background"
+
+# generate image
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt, num_inference_steps=20, generator=generator, image=control_image
+).images[0]
+image.save("./output.png")
+```
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+### Running on Google Cloud TPU
+
+See below for commands to set up a TPU VM (`--accelerator-type v4-8`). For more details about how to set up and use TPUs, refer to the [Cloud docs for single VM setup](https://cloud.google.com/tpu/docs/run-calculation-jax).
+
+First create a single TPUv4-8 VM and connect to it:
+
+```
+ZONE=us-central2-b
+TPU_TYPE=v4-8
+VM_NAME=hg_flax
+
+gcloud alpha compute tpus tpu-vm create $VM_NAME \
+ --zone $ZONE \
+ --accelerator-type $TPU_TYPE \
+ --version tpu-vm-v4-base
+
+gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE
+```
+
+When connected install JAX `0.4.5`:
+
+```sh
+pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+To verify that JAX was correctly installed, you can run the following command:
+
+```py
+import jax
+jax.device_count()
+```
+
+This should display the number of TPU cores, which should be 4 on a TPUv4-8 VM.
+
+Then install Diffusers and the library's training dependencies:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+If you want to use Weights and Biases logging, you should also install `wandb` now
+
+```bash
+pip install wandb
+```
+
+
+Now let's download two conditioning images that we will use to run validation during training in order to track our progress
+
+```sh
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+We encourage you to store or share your model with the community. To use the Hugging Face Hub, please log in to your Hugging Face account, or [create one](https://huggingface.co/docs/diffusers/main/en/training/hf.co/join) if you don’t have one already:
+
+```sh
+huggingface-cli login
+```
+
+Make sure you have the `MODEL_DIR`, `OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. `OUTPUT_DIR` specifies where the model is saved locally, and `HUB_MODEL_ID` the name of the repository it is pushed to on the Hub:
+
+```bash
+export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="runs/fill-circle-{timestamp}"
+export HUB_MODEL_ID="controlnet-fill-circle"
+```
+
+And finally start the training
+
+```bash
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=1000 \
+ --train_batch_size=2 \
+ --revision="non-ema" \
+ --from_pt \
+ --report_to="wandb" \
+ --tracker_project_name=$HUB_MODEL_ID \
+ --num_train_epochs=11 \
+ --push_to_hub \
+ --hub_model_id=$HUB_MODEL_ID
+```
+
+Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your Hugging Face account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the Hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
+
+Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):
+
+```bash
+export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
+export HUB_MODEL_ID="controlnet-uncanny-faces"
+
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=multimodalart/facesyntheticsspigacaptioned \
+ --streaming \
+ --conditioning_image_column=spiga_seg \
+ --image_column=image \
+ --caption_column=image_caption \
+ --resolution=512 \
+ --max_train_samples 100000 \
+ --learning_rate=1e-5 \
+ --train_batch_size=1 \
+ --revision="flax" \
+ --report_to="wandb" \
+ --tracker_project_name=$HUB_MODEL_ID
+```
+
+Note, however, that TPU performance might get bottlenecked, as streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore the following options:
+
+* [Webdataset](https://webdataset.github.io/webdataset/)
+* [TorchData](https://github.com/pytorch/data)
+* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
+
+When working with a larger dataset, you may need to run the training process for a long time, so it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
+
+```bash
+ --checkpointing_steps=500
+```
+This will save the trained model in subfolders of your output_dir. Each subfolder name is the number of steps performed so far; for example, a checkpoint saved after 500 training steps would be saved in a subfolder named `500`.
+
+You can then start your training from this saved checkpoint with
+
+```bash
+ --controlnet_model_name_or_path="./control_out/500"
+```
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
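+
+For example, appended to the Flax training command from earlier (a sketch only; the other flags are unchanged):
+
+```bash
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --train_batch_size=2 \
+ --snr_gamma=5.0
+```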
+
+We also support gradient accumulation, a technique that lets you train with a larger effective batch size than your machine would normally be able to fit into memory. Use the `--gradient_accumulation_steps` argument to set the number of accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
+
+You can **profile your code** with:
+
+```bash
+ --profile_steps=5
+```
+
+Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin:
+
+```bash
+pip install tensorflow tensorboard-plugin-profile
+tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
+```
+
+The profile can then be inspected at http://localhost:6006/#profile
+
+Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of TensorFlow/TensorBoard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
+
+Note that the debugging functionality of the TensorBoard `profile` plugin is still under active development. Not all views are fully functional; for example, the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you accidentally profile the compilation step).
+
+## Support for Stable Diffusion XL
+
+We provide a training script for training a ControlNet with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to [README_sdxl.md](./README_sdxl.md) for more details.
diff --git a/diffusers/examples/controlnet/README_sd3.md b/diffusers/examples/controlnet/README_sd3.md
new file mode 100644
index 0000000000000000000000000000000000000000..1788e07a21d69297f12f2f3dae6645305940cf19
--- /dev/null
+++ b/diffusers/examples/controlnet/README_sd3.md
@@ -0,0 +1,152 @@
+# ControlNet training example for Stable Diffusion 3 (SD3)
+
+The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/controlnet` folder and run
+```bash
+pip install -r requirements_sd3.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can yield dramatic speedups.
+
+## Circle filling dataset
+
+The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
+Please download the dataset and unzip it in the directory `fill50k` in the `examples/controlnet` folder.
+
+## Training
+
+First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium). We will use it as a base model for the ControlNet training.
+> [!NOTE]
+> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+
+```bash
+huggingface-cli login
+```
+
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
+
+Our training examples use two test conditioning images. They can be downloaded by running
+
+```sh
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+Then run the following commands to train a ControlNet model.
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-3-medium-diffusers"
+export OUTPUT_DIR="sd3-controlnet-out"
+
+accelerate launch train_controlnet_sd3.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --train_data_dir="fill50k" \
+ --resolution=1024 \
+ --learning_rate=1e-5 \
+ --max_train_steps=15000 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=100 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4
+```
+
+To better track our training experiments, we're using flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+### Inference
+
+Once training is done, we can perform inference like so:
+
+```python
+from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel
+from diffusers.utils import load_image
+import torch
+
+base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
+controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet"
+
+controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
+ base_model_path, controlnet=controlnet
+)
+pipe.to("cuda", torch.float16)
+
+
+control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
+prompt = "pale golden rod circle with old lace background"
+
+# generate image
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt, num_inference_steps=20, generator=generator, control_image=control_image
+).images[0]
+image.save("./output.png")
+```
+
+## Notes
+
+### GPU usage
+
+SD3 is a large model and requires a lot of GPU memory.
+We recommend using one GPU with at least 80GB of memory.
+Make sure to use the right GPU when configuring the [accelerator](https://huggingface.co/docs/transformers/en/accelerate).
+
+
+## Example results
+
+#### After 500 steps with batch size 8
+
+| | |
+|-------------------|:-------------------------:|
+| | pale golden rod circle with old lace background |
+
+
+#### After 6500 steps with batch size 8:
+
+| | |
+|-------------------|:-------------------------:|
+| | pale golden rod circle with old lace background |
+
diff --git a/diffusers/examples/controlnet/README_sdxl.md b/diffusers/examples/controlnet/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..75511385ff37d5dc198a22c68731dcbf224f8fe0
--- /dev/null
+++ b/diffusers/examples/controlnet/README_sdxl.md
@@ -0,0 +1,141 @@
+# ControlNet training example for Stable Diffusion XL (SDXL)
+
+The `train_controlnet_sdxl.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/controlnet` folder and run
+```bash
+pip install -r requirements_sdxl.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can yield dramatic speedups.
+
+## Circle filling dataset
+
+The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
+
+## Training
+
+Our training examples use two test conditioning images. They can be downloaded by running
+
+```sh
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --mixed_precision="fp16" \
+ --resolution=1024 \
+ --learning_rate=1e-5 \
+ --max_train_steps=15000 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=100 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --report_to="wandb" \
+ --seed=42 \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+### Inference
+
+Once training is done, we can perform inference like so:
+
+```python
+from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from diffusers.utils import load_image
+import torch
+
+base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+controlnet_path = "path to controlnet"
+
+controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16
+)
+
+# speed up diffusion process with faster scheduler and memory optimization
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+# remove following line if xformers is not installed or when using Torch 2.0.
+pipe.enable_xformers_memory_efficient_attention()
+# memory optimization.
+pipe.enable_model_cpu_offload()
+
+control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
+prompt = "pale golden rod circle with old lace background"
+
+# generate image
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt, num_inference_steps=20, generator=generator, image=control_image
+).images[0]
+image.save("./output.png")
+```
+
+## Notes
+
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
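+
+During training, that typically just means adding the flag to the command shown above (a sketch; only the `--pretrained_vae_model_name_or_path` line is new):
+
+```bash
+accelerate launch train_controlnet_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --mixed_precision="fp16" \
+ --resolution=1024 \
+ --learning_rate=1e-5 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4
+```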
+
+If you're using this VAE during training, you need to ensure you're using it during inference too. You do so by:
+
+```diff
++ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
+controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16,
++ vae=vae,
+)
+```
diff --git a/diffusers/examples/controlnet/requirements.txt b/diffusers/examples/controlnet/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d19c62296702868c768596bdd866dd5b504e4180
--- /dev/null
+++ b/diffusers/examples/controlnet/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+datasets
diff --git a/diffusers/examples/controlnet/requirements_flax.txt b/diffusers/examples/controlnet/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b6eb64e254625ee8eff2ef126d67adfd5b6994dc
--- /dev/null
+++ b/diffusers/examples/controlnet/requirements_flax.txt
@@ -0,0 +1,9 @@
+transformers>=4.25.1
+datasets
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/controlnet/requirements_sd3.txt b/diffusers/examples/controlnet/requirements_sd3.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5ab6e9932e10a1e5337f3bc3faa8a192f4f60a52
--- /dev/null
+++ b/diffusers/examples/controlnet/requirements_sd3.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+datasets
+wandb
diff --git a/diffusers/examples/controlnet/requirements_sdxl.txt b/diffusers/examples/controlnet/requirements_sdxl.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5ab6e9932e10a1e5337f3bc3faa8a192f4f60a52
--- /dev/null
+++ b/diffusers/examples/controlnet/requirements_sdxl.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+datasets
+wandb
diff --git a/diffusers/examples/controlnet/test_controlnet.py b/diffusers/examples/controlnet/test_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..77b5614c7fb0362447529d95c05b948510215f09
--- /dev/null
+++ b/diffusers/examples/controlnet/test_controlnet.py
@@ -0,0 +1,138 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class ControlNet(ExamplesTestsAccelerate):
+ def test_controlnet_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=6
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
+ )
+
+ resume_run_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-6
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+
+class ControlNetSDXL(ExamplesTestsAccelerate):
+ def test_controlnet_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet_sdxl.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+
+
+class ControlNetSD3(ExamplesTestsAccelerate):
+ def test_controlnet_sd3(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet_sd3.py
+ --pretrained_model_name_or_path=DavyMorgan/tiny-sd3-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd3
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
diff --git a/diffusers/examples/controlnet/train_controlnet.py b/diffusers/examples/controlnet/train_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e902db7ffc7e64a7c29977b93370be2afe7b9fb
--- /dev/null
+++ b/diffusers/examples/controlnet/train_controlnet.py
@@ -0,0 +1,1187 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ ControlNetModel,
+ DDPMScheduler,
+ StableDiffusionControlNetPipeline,
+ UNet2DConditionModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def image_grid(imgs, rows, cols):
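+    # Paste `rows * cols` PIL images into a single grid image (used for the model card preview).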
+ assert len(imgs) == rows * cols
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+def log_validation(
+ vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
+):
+ logger.info("Running validation... ")
+
+ if not is_final_validation:
+ controlnet = accelerator.unwrap_model(controlnet)
+ else:
+ controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+ inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with inference_ctx:
+ image = pipeline(
+ validation_prompt, validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images = []
+
+ formatted_images.append(np.asarray(validation_image))
+
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with a new type of conditioning.
+{img_str}
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "text-to-image",
+ "diffusers",
+ "controlnet",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+ " If not specified controlnet weights are initialized from unet.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="controlnet-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running each prompt in"
+ " `args.validation_prompt` `args.num_validation_images` times"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="train_controlnet",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+def make_train_dataset(args, tokenizer, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ if args.train_data_dir is not None:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if random.random() < args.proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+ examples["input_ids"] = tokenize_captions(examples)
+
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ return train_dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "input_ids": input_ids,
+ }
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from unet")
+ controlnet = ControlNetModel.from_unet(unet)
+
+ # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ controlnet.train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
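+ # Illustrative note: with the defaults above, this linear scaling rule gives
+ # 5e-6 * 1 (grad. accumulation) * 4 (per-device batch) * num_processes, e.g. 4e-5 on 2 GPUs.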
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = controlnet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ train_dataset = make_train_dataset(args, tokenizer, accelerator)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ controlnet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(controlnet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
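+ # Note: `scaling_factor` (0.18215 for the Stable Diffusion 1.x/2.x VAE) rescales the
+ # latents to roughly unit variance, which is the scale the UNet was trained on.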
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
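+ # For a DDPM-style scheduler this is the closed-form forward process:
+ # noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise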
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
+
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+
+ down_block_res_samples, mid_block_res_sample = controlnet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=controlnet_image,
+ return_dict=False,
+ )
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ down_block_additional_residuals=[
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+ ],
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
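+ # Note: for "v_prediction", `get_velocity` returns
+ # v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents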
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = controlnet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ controlnet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ controlnet = unwrap_model(controlnet)
+ controlnet.save_pretrained(args.output_dir)
+
+ # Run a final round of validation.
+ image_logs = None
+ if args.validation_prompt is not None:
+ image_logs = log_validation(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=None,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/controlnet/train_controlnet_flax.py b/diffusers/examples/controlnet/train_controlnet_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aa9e881fca53f87c1901747c08032adcb27cb2d
--- /dev/null
+++ b/diffusers/examples/controlnet/train_controlnet_flax.py
@@ -0,0 +1,1152 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import time
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import torch
+import torch.utils.checkpoint
+import transformers
+from datasets import load_dataset, load_from_disk
+from flax import jax_utils
+from flax.core.frozen_dict import unfreeze
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+from PIL import Image, PngImagePlugin
+from torch.utils.data import IterableDataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxControlNetModel,
+ FlaxDDPMScheduler,
+ FlaxStableDiffusionControlNetPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+
+
+# To prevent an error that occurs when there is an abnormally large compressed data chunk in the PNG image,
+# see https://github.com/python-pillow/Pillow/issues/5610
+LARGE_ENOUGH_NUMBER = 100
+PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = logging.getLogger(__name__)
+
+
+def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
+ logger.info("Running validation...")
+
+ pipeline_params = pipeline_params.copy()
+ pipeline_params["controlnet"] = controlnet_params
+
+ num_samples = jax.device_count()
+ prng_seed = jax.random.split(rng, jax.device_count())
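+ # Note: the PRNG key is split into one key per device so that each shard of the
+ # jit/pmapped pipeline call below samples different noise.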
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should match; this is validated in `parse_args`"
+ )
+
+ image_logs = []
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ prompts = num_samples * [validation_prompt]
+ prompt_ids = pipeline.prepare_text_inputs(prompts)
+ prompt_ids = shard(prompt_ids)
+
+ validation_image = Image.open(validation_image).convert("RGB")
+ processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
+ processed_image = shard(processed_image)
+ images = pipeline(
+ prompt_ids=prompt_ids,
+ image=processed_image,
+ params=pipeline_params,
+ prng_seed=prng_seed,
+ num_inference_steps=50,
+ jit=True,
+ ).images
+
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
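+ # The pipeline returns images shaped (num_devices, per_device_batch, H, W, C);
+ # the reshape above merges the first two dims into a single batch dimension.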
+ images = pipeline.numpy_to_pil(images)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ if args.report_to == "wandb":
+ formatted_images = []
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ wandb.log({"validation": formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {args.report_to}")
+
+ return image_logs
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with a new type of conditioning. You can find some example images below.\n
+{img_str}
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "text-to-image",
+ "diffusers",
+ "controlnet",
+ "jax-diffusers-event",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+ " If not specified controlnet weights are initialized from unet.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--from_pt",
+ action="store_true",
+ help="Load the pretrained model from a PyTorch checkpoint.",
+ )
+ parser.add_argument(
+ "--controlnet_revision",
+ type=str,
+ default=None,
+ help="Revision of controlnet model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--profile_steps",
+ type=int,
+ default=0,
+ help="How many training steps to profile in the beginning.",
+ )
+ parser.add_argument(
+ "--profile_validation",
+ action="store_true",
+ help="Whether to profile the (last) validation.",
+ )
+ parser.add_argument(
+ "--profile_memory",
+ action="store_true",
+ help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
+ )
+ parser.add_argument(
+ "--ccache",
+ type=str,
+ default=None,
+ help="Enables compilation cache.",
+ )
+ parser.add_argument(
+ "--controlnet_from_pt",
+ action="store_true",
+ help="Load the controlnet model from a PyTorch checkpoint.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="runs/{timestamp}",
+ help="The output directory where the model predictions and checkpoints will be written. "
+ "Can contain placeholders: {timestamp}.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=5000,
+ help=("Save a checkpoint of the training state every X updates."),
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_steps",
+ type=int,
+ default=100,
+ help=("Log training metrics every X steps to `--report_to`."),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="wandb",
+ help=('The integration to report the results and logs to. Currently the only supported platform is `"wandb"`.'),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
+ " and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training dataset. By default it will use the `load_dataset` method to load a custom dataset from the folder. "
+ "The folder must contain a dataset script as described here: https://huggingface.co/docs/datasets/dataset_script. "
+ "If the `--load_from_disk` flag is passed, it will use the `load_from_disk` method instead. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--load_from_disk",
+ action="store_true",
+ help=(
+ "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`. "
+ "See https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set. Needed if `streaming` is set to True."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` and logging the images."
+ ),
+ )
+ parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams)."))
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="train_controlnet_flax",
+ help=("The `project` argument passed to wandb"),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps", type=int, default=1, help="Number of steps to accumulate gradients over"
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ # This idea comes from
+ # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
+ if args.streaming and args.max_train_samples is None:
+ raise ValueError("You must specify `max_train_samples` when using dataset streaming.")
+
+ return args
+
+
+def make_train_dataset(args, tokenizer, batch_size=None):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ streaming=args.streaming,
+ )
+ else:
+ if args.train_data_dir is not None:
+ if args.load_from_disk:
+ dataset = load_from_disk(
+ args.train_data_dir,
+ )
+ else:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ if isinstance(dataset["train"], IterableDataset):
+ column_names = next(iter(dataset["train"])).keys()
+ else:
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if random.random() < args.proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+ examples["input_ids"] = tokenize_captions(examples)
+
+ return examples
+
+ if jax.process_index() == 0:
+ if args.max_train_samples is not None:
+ if args.streaming:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
+ else:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ if args.streaming:
+ train_dataset = dataset["train"].map(
+ preprocess_train,
+ batched=True,
+ batch_size=batch_size,
+ remove_columns=list(dataset["train"].features.keys()),
+ )
+ else:
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ return train_dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+
+ batch = {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "input_ids": input_ids,
+ }
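+ # convert torch tensors to NumPy arrays so the batch can later be sharded across JAX devices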
+ batch = {k: v.numpy() for k, v in batch.items()}
+ return batch
+
+
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ # Set up logging; we only want one process per machine to log things on the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+ # wandb init
+ if jax.process_index() == 0 and args.report_to == "wandb":
+ wandb.init(
+ entity=args.wandb_entity,
+ project=args.tracker_project_name,
+ job_type="train",
+ config=args,
+ )
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ rng = jax.random.PRNGKey(0)
+
+ # Handle the repository creation
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ else:
+ raise NotImplementedError("No tokenizer specified!")
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
+ train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=not args.streaming,
+ collate_fn=collate_fn,
+ batch_size=total_train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ drop_last=True,
+ )
+
+ weight_dtype = jnp.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = jnp.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = jnp.bfloat16
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ dtype=weight_dtype,
+ revision=args.revision,
+ from_pt=args.from_pt,
+ )
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ subfolder="vae",
+ dtype=weight_dtype,
+ from_pt=args.from_pt,
+ )
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ dtype=weight_dtype,
+ revision=args.revision,
+ from_pt=args.from_pt,
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+ args.controlnet_model_name_or_path,
+ revision=args.controlnet_revision,
+ from_pt=args.controlnet_from_pt,
+ dtype=jnp.float32,
+ )
+ else:
+ logger.info("Initializing controlnet weights from unet")
+ rng, rng_params = jax.random.split(rng)
+
+ controlnet = FlaxControlNetModel(
+ in_channels=unet.config.in_channels,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ attention_head_dim=unet.config.attention_head_dim,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ use_linear_projection=unet.config.use_linear_projection,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ )
+ controlnet_params = controlnet.init_weights(rng=rng_params)
+ controlnet_params = unfreeze(controlnet_params)
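+ # copy the pretrained UNet's encoder weights into the matching ControlNet blocks;
+ # the remaining ControlNet-specific layers keep their freshly initialized values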
+ for key in [
+ "conv_in",
+ "time_embedding",
+ "down_blocks_0",
+ "down_blocks_1",
+ "down_blocks_2",
+ "down_blocks_3",
+ "mid_block",
+ ]:
+ controlnet_params[key] = unet_params[key]
+
+ pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ controlnet=controlnet,
+ safety_checker=None,
+ dtype=weight_dtype,
+ revision=args.revision,
+ from_pt=args.from_pt,
+ )
+ pipeline_params = jax_utils.replicate(pipeline_params)
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ adamw = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ optimizer = optax.chain(
+ optax.clip_by_global_norm(args.max_grad_norm),
+ adamw,
+ )
+
+ state = train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer)
+
+ noise_scheduler, noise_scheduler_state = FlaxDDPMScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+
+ # Initialize our training
+ validation_rng, train_rngs = jax.random.split(rng)
+ train_rngs = jax.random.split(train_rngs, jax.local_device_count())
+
+ def compute_snr(timesteps):
+ """
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ """
+ alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+ alpha = sqrt_alphas_cumprod[timesteps]
+ sigma = sqrt_one_minus_alphas_cumprod[timesteps]
+ # Compute SNR.
+ snr = (alpha / sigma) ** 2
+ return snr
+
+ def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
+ # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
+ if args.gradient_accumulation_steps > 1:
+ grad_steps = args.gradient_accumulation_steps
+ batch = jax.tree_map(lambda x: x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:]), batch)
+
+ def compute_loss(params, minibatch, sample_rng):
+ # Convert images to latent space
+ vae_outputs = vae.apply(
+ {"params": vae_params}, minibatch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ # Sample a random timestep for each image
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(
+ minibatch["input_ids"],
+ params=text_encoder_params,
+ train=False,
+ )[0]
+
+ controlnet_cond = minibatch["conditioning_pixel_values"]
+
+ # Predict the noise residual and compute loss
+ down_block_res_samples, mid_block_res_sample = controlnet.apply(
+ {"params": params},
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states,
+ controlnet_cond,
+ train=True,
+ return_dict=False,
+ )
+
+ model_pred = unet.apply(
+ {"params": unet_params},
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = (target - model_pred) ** 2
+
+ if args.snr_gamma is not None:
+ snr = jnp.array(compute_snr(timesteps))
+ snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma)
+ if noise_scheduler.config.prediction_type == "epsilon":
+ snr_loss_weights = snr_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ snr_loss_weights = snr_loss_weights / (snr + 1)
+
+ loss = loss * snr_loss_weights
+
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+
+ # get a minibatch (one gradient accumulation slice)
+ def get_minibatch(batch, grad_idx):
+ return jax.tree_util.tree_map(
+ lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
+ batch,
+ )
+
+ def loss_and_grad(grad_idx, train_rng):
+ # create minibatch for the grad step
+ minibatch = get_minibatch(batch, grad_idx) if grad_idx is not None else batch
+ sample_rng, train_rng = jax.random.split(train_rng, 2)
+ loss, grad = grad_fn(state.params, minibatch, sample_rng)
+ return loss, grad, train_rng
+
+ if args.gradient_accumulation_steps == 1:
+ loss, grad, new_train_rng = loss_and_grad(None, train_rng)
+ else:
+ init_loss_grad_rng = (
+ 0.0, # initial value for cumul_loss
+ jax.tree_map(jnp.zeros_like, state.params), # initial value for cumul_grad
+ train_rng, # initial value for train_rng
+ )
+
+ def cumul_grad_step(grad_idx, loss_grad_rng):
+ cumul_loss, cumul_grad, train_rng = loss_grad_rng
+ loss, grad, new_train_rng = loss_and_grad(grad_idx, train_rng)
+ cumul_loss, cumul_grad = jax.tree_map(jnp.add, (cumul_loss, cumul_grad), (loss, grad))
+ return cumul_loss, cumul_grad, new_train_rng
+
+ loss, grad, new_train_rng = jax.lax.fori_loop(
+ 0,
+ args.gradient_accumulation_steps,
+ cumul_grad_step,
+ init_loss_grad_rng,
+ )
+ loss, grad = jax.tree_map(lambda x: x / args.gradient_accumulation_steps, (loss, grad))
+
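+ # average gradients across data-parallel devices over the pmap axis named "batch"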
+ grad = jax.lax.pmean(grad, "batch")
+
+ new_state = state.apply_gradients(grads=grad)
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+ def l2(xs):
+ return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
+
+ metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))
+
+ return new_state, metrics, new_train_rng
+
+ # Create parallel version of the train step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+ unet_params = jax_utils.replicate(unet_params)
+ text_encoder_params = jax_utils.replicate(text_encoder.params)
+ vae_params = jax_utils.replicate(vae_params)
+
+ # Train!
+ if args.streaming:
+ dataset_length = args.max_train_samples
+ else:
+ dataset_length = len(train_dataloader)
+ num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
+
+ if jax.process_index() == 0 and args.report_to == "wandb":
+ wandb.define_metric("*", step_metric="train/step")
+ wandb.define_metric("train/step", step_metric="walltime")
+ wandb.config.update(
+ {
+ "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
+ "total_train_batch_size": total_train_batch_size,
+ "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
+ "num_devices": jax.device_count(),
+ "controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
+ }
+ )
+
+ global_step = step0 = 0
+ epochs = tqdm(
+ range(args.num_train_epochs),
+ desc="Epoch ... ",
+ position=0,
+ disable=jax.process_index() > 0,
+ )
+ if args.profile_memory:
+ jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
+ t00 = t0 = time.monotonic()
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+ train_metric = None
+
+ steps_per_epoch = (
+ args.max_train_samples // total_train_batch_size
+ if args.streaming or args.max_train_samples
+ else len(train_dataset) // total_train_batch_size
+ )
+ train_step_progress_bar = tqdm(
+ total=steps_per_epoch,
+ desc="Training...",
+ position=1,
+ leave=False,
+ disable=jax.process_index() > 0,
+ )
+ # train
+ for batch in train_dataloader:
+ if args.profile_steps and global_step == 1:
+ train_metric["loss"].block_until_ready()
+ jax.profiler.start_trace(args.output_dir)
+ if args.profile_steps and global_step == 1 + args.profile_steps:
+ train_metric["loss"].block_until_ready()
+ jax.profiler.stop_trace()
+
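+ # reshape the host batch to (local_device_count, per_device_batch, ...) so pmap can consume it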
+ batch = shard(batch)
+ with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
+ state, train_metric, train_rngs = p_train_step(
+ state, unet_params, text_encoder_params, vae_params, batch, train_rngs
+ )
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(1)
+
+ global_step += 1
+ if global_step >= args.max_train_steps:
+ break
+
+ if (
+ args.validation_prompt is not None
+ and global_step % args.validation_steps == 0
+ and jax.process_index() == 0
+ ):
+ _ = log_validation(
+ pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
+ )
+
+ if global_step % args.logging_steps == 0 and jax.process_index() == 0:
+ if args.report_to == "wandb":
+ train_metrics = jax_utils.unreplicate(train_metrics)
+ train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
+ wandb.log(
+ {
+ "walltime": time.monotonic() - t00,
+ "train/step": global_step,
+ "train/epoch": global_step / dataset_length,
+ "train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
+ **{f"train/{k}": v for k, v in train_metrics.items()},
+ }
+ )
+ t0, step0 = time.monotonic(), global_step
+ train_metrics = []
+ if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
+ controlnet.save_pretrained(
+ f"{args.output_dir}/{global_step}",
+ params=get_params_to_save(state.params),
+ )
+
+ train_metric = jax_utils.unreplicate(train_metric)
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+ # Final validation & store model.
+ if jax.process_index() == 0:
+ if args.validation_prompt is not None:
+ if args.profile_validation:
+ jax.profiler.start_trace(args.output_dir)
+ image_logs = log_validation(
+ pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
+ )
+ if args.profile_validation:
+ jax.profiler.stop_trace()
+ else:
+ image_logs = None
+
+ controlnet.save_pretrained(
+ args.output_dir,
+ params=get_params_to_save(state.params),
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ if args.profile_memory:
+ jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
+ logger.info("Finished training.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/controlnet/train_controlnet_sd3.py b/diffusers/examples/controlnet/train_controlnet_sd3.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b255c501d99aadd39015d99056966c7bfa18a6f
--- /dev/null
+++ b/diffusers/examples/controlnet/train_controlnet_sd3.py
@@ -0,0 +1,1416 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import copy
+import functools
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ SD3ControlNetModel,
+ SD3Transformer2DModel,
+ StableDiffusion3ControlNetPipeline,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ clear_objs_and_retain_memory,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+)
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.30.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def image_grid(imgs, rows, cols):
+ assert len(imgs) == rows * cols
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
+ logger.info("Running validation... ")
+
+ if not is_final_validation:
+ controlnet = accelerator.unwrap_model(controlnet)
+ else:
+ controlnet = SD3ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+
+ pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ controlnet=controlnet,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(torch.device(accelerator.device))
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
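+ # during training, validation runs the fp32 controlnet under autocast; the final validation
+ # reloads the saved weights directly in `weight_dtype`, so no autocast context is needed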
+ inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with inference_ctx:
+ image = pipeline(
+ validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ tracker.writer.add_image(
+ "Controlnet conditioning", np.asarray([validation_image]), step, dataformats="NHWC"
+ )
+
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ clear_objs_and_retain_memory(pipeline)
+
+ if not is_final_validation:
+ controlnet.to(accelerator.device)
+
+ return image_logs
+
+
+# Copied from dreambooth sd3 example
+def load_text_encoders(class_one, class_two, class_three):
+ text_encoder_one = class_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = class_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ text_encoder_three = class_three.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
+ )
+ return text_encoder_one, text_encoder_two, text_encoder_three
+
+
+# Copied from dreambooth sd3 example
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+ if model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# SD3 controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with new type of conditioning.
+The weights were trained using [ControlNet](https://github.com/lllyasviel/ControlNet) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sd3.md).
+{img_str}
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "sd3",
+ "sd3-diffusers",
+ "controlnet",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+ " If not specified controlnet weights are initialized from unet.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="controlnet-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="logit_normal",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--precondition_outputs",
+ type=int,
+ default=1,
+ help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
+ "model `target` is calculated.",
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=77,
+ help="Maximum sequence length to use with with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="train_controlnet",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
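+# Illustrative invocation (a sketch only; the model id, dataset name and output path below are
+# placeholders, not defaults of this script):
+#   accelerate launch train_controlnet_sd3.py \
+#     --pretrained_model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers" \
+#     --dataset_name="<your_dataset>" \
+#     --image_column=image \
+#     --conditioning_image_column=conditioning_image \
+#     --caption_column=text \
+#     --resolution=512 \
+#     --train_batch_size=1 \
+#     --gradient_accumulation_steps=4 \
+#     --output_dir="controlnet-sd3"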
+
+def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ if args.train_data_dir is not None:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ def process_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if random.random() < args.proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ return captions
+
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+ examples["prompts"] = process_captions(examples)
+
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ return train_dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
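+ # prompt embeddings were precomputed and cached via `dataset.map` in `main`, so rebuild tensors here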
+ prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+ pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ }
+
+
+# Copied from dreambooth sd3 example
+def _encode_prompt_with_t5(
+ text_encoder,
+ tokenizer,
+ max_sequence_length,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+# Copied from dreambooth sd3 example
+def _encode_prompt_with_clip(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+# Copied from dreambooth sd3 example
+def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ clip_tokenizers = tokenizers[:2]
+ clip_text_encoders = text_encoders[:2]
+
+ clip_prompt_embeds_list = []
+ clip_pooled_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ prompt=prompt,
+ device=device if device is not None else text_encoder.device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ clip_prompt_embeds_list.append(prompt_embeds)
+ clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
+
+ clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
+
+ t5_prompt_embed = _encode_prompt_with_t5(
+ text_encoders[-1],
+ tokenizers[-1],
+ max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device if device is not None else text_encoders[-1].device,
+ )
+
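+ # pad the concatenated CLIP embeddings up to the T5 hidden size, then join CLIP and T5 tokens along the sequence dimension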
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ tokenizer_one = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+ tokenizer_two = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ )
+ tokenizer_three = T5TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_3",
+ revision=args.revision,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+ text_encoder_cls_three = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
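+ # keep an unmodified copy of the scheduler for training-time lookups (e.g. per-timestep sigmas)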
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
+ text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = SD3Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = SD3ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from transformer")
+ controlnet = SD3ControlNetModel.from_transformer(transformer)
+
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ text_encoder_three.requires_grad_(False)
+ controlnet.train()
+
+ # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ i = len(weights) - 1
+
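+ # pop each weight so accelerate's default serialization is skipped; the model is saved in diffusers format below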
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = SD3ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = controlnet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, transformer and text_encoder to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=torch.float32)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_three.to(accelerator.device, dtype=weight_dtype)
+
+ train_dataset = make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, accelerator)
+
+ tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
+ text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
+
+ def compute_text_embeddings(batch, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt = batch["prompts"]
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders, tokenizers, prompt, args.max_sequence_length
+ )
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
+
+ compute_embeddings_fn = functools.partial(
+ compute_text_embeddings,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+
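+ # the prompt embeddings are now cached in the dataset, so the text encoders and tokenizers can be freed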
+ clear_objs_and_retain_memory(text_encoders + tokenizers)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ controlnet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
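+        # Look up the flow-matching sigma of the copied noise scheduler for each sampled
+        # timestep and unsqueeze it to `n_dim` so it broadcasts against the latent tensors.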
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(controlnet):
+ # Convert images to latent space
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
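+                # Here `sigmas` plays the role of `texp` above: sigma = 1 gives pure noise,
+                # sigma = 0 recovers the clean latents.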
+
+ # Get the text embedding for conditioning
+ prompt_embeds = batch["prompt_embeds"]
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"]
+
+ # controlnet(s) inference
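+                # Unlike pixel-space ControlNets, the conditioning image is encoded with the
+                # frozen VAE so the ControlNet receives it in latent space.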
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+ controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
+ controlnet_image = controlnet_image * vae.config.scaling_factor
+
+ control_block_res_samples = controlnet(
+ hidden_states=noisy_model_input,
+ timestep=timesteps,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ controlnet_cond=controlnet_image,
+ return_dict=False,
+ )[0]
+ control_block_res_samples = [sample.to(dtype=weight_dtype) for sample in control_block_res_samples]
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ timestep=timesteps,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ block_controlnet_hidden_states=control_block_res_samples,
+ return_dict=False,
+ )[0]
+
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ # Preconditioning of the model outputs.
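+                # This turns the raw velocity prediction into an estimate of the clean latents x0.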
+ if args.precondition_outputs:
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
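+                # With preconditioning the network is regressed onto the clean latents;
+                # otherwise the target is the flow-matching velocity (noise - x0).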
+ if args.precondition_outputs:
+ target = model_input
+ else:
+ target = noise - model_input
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = controlnet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ controlnet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ controlnet = unwrap_model(controlnet)
+ controlnet.save_pretrained(args.output_dir)
+
+ # Run a final round of validation.
+ image_logs = None
+ if args.validation_prompt is not None:
+ image_logs = log_validation(
+ controlnet=None,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/controlnet/train_controlnet_sdxl.py b/diffusers/examples/controlnet/train_controlnet_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..877ca61358499168a633480e86f349648e4e0e89
--- /dev/null
+++ b/diffusers/examples/controlnet/train_controlnet_sdxl.py
@@ -0,0 +1,1346 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ ControlNetModel,
+ DDPMScheduler,
+ StableDiffusionXLControlNetPipeline,
+ UNet2DConditionModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+if is_torch_npu_available():
+ torch.npu.config.allow_internal_format = False
+
+
+def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
+ logger.info("Running validation... ")
+
+ if not is_final_validation:
+ controlnet = accelerator.unwrap_model(controlnet)
+ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=unet,
+ controlnet=controlnet,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ else:
+ controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
+ else:
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
+ )
+
+ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ controlnet=controlnet,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+ if is_final_validation or torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with autocast_ctx:
+ image = pipeline(
+ prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images = []
+
+ formatted_images.append(np.asarray(validation_image))
+
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with a new type of conditioning.
+{img_str}
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers",
+ "controlnet",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+ " If not specified controlnet weights are initialized from unet.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="controlnet-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+        help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
+            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
+            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step"
+            " instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+            "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
+            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
+            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="sd_xl_train_controlnet",
+ help=(
+            "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+def get_train_dataset(args, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ if args.train_data_dir is not None:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+    # Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ with accelerator.main_process_first():
+ train_dataset = dataset["train"].shuffle(seed=args.seed)
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(args.max_train_samples))
+ return train_dataset
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
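+    # SDXL conditions on two text encoders: their penultimate hidden states are concatenated
+    # along the feature dimension, and the pooled embedding comes from the last encoder.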
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+            # We are always only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def prepare_train_dataset(dataset, accelerator):
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[args.image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+
+ return examples
+
+ with accelerator.main_process_first():
+ dataset = dataset.with_transform(preprocess_train)
+
+ return dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+
+ add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples])
+ add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "prompt_ids": prompt_ids,
+ "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
+ }
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from unet")
+ controlnet = ControlNetModel.from_unet(unet)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
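+                    # Popping the weight prevents `accelerator.save_state` from serializing it
+                    # again in its default format; it is saved below as a diffusers `controlnet/` folder.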
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ controlnet.train()
+
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ unet.enable_npu_flash_attention()
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ controlnet.enable_gradient_checkpointing()
+ unet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = controlnet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ else:
+ vae.to(accelerator.device, dtype=torch.float32)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):
+ original_size = (args.resolution, args.resolution)
+ target_size = (args.resolution, args.resolution)
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ prompt_batch = batch[args.caption_column]
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
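+        # SDXL micro-conditioning: the original size, top-left crop coordinates and target size
+        # are packed into `add_time_ids` and passed to the UNet through `added_cond_kwargs`.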
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ train_dataset = get_train_dataset(args, accelerator)
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ )
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+
+ del text_encoders, tokenizers
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Then get the training dataset ready to be passed to the dataloader.
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ controlnet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(controlnet):
+ # Convert images to latent space
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+ else:
+ pixel_values = batch["pixel_values"]
+ latents = vae.encode(pixel_values).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
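+                # Scaling by the VAE's scaling factor matches the latent distribution the UNet was trained on.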
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # ControlNet conditioning.
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+ down_block_res_samples, mid_block_res_sample = controlnet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=batch["prompt_ids"],
+ added_cond_kwargs=batch["unet_added_conditions"],
+ controlnet_cond=controlnet_image,
+ return_dict=False,
+ )
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=batch["prompt_ids"],
+ added_cond_kwargs=batch["unet_added_conditions"],
+ down_block_additional_residuals=[
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+ ],
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
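+                # "epsilon" regresses the model onto the added noise, while "v_prediction" regresses it
+                # onto the velocity sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0.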
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = controlnet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ vae=vae,
+ unet=unet,
+ controlnet=controlnet,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ controlnet = unwrap_model(controlnet)
+ controlnet.save_pretrained(args.output_dir)
+
+ # Run a final round of validation.
+ # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
+ image_logs = None
+ if args.validation_prompt is not None:
+ image_logs = log_validation(
+ vae=None,
+ unet=None,
+ controlnet=None,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/custom_diffusion/README.md b/diffusers/examples/custom_diffusion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9b0b52e5ae7e18e4ea8b0215c8f6cba5058b56ce
--- /dev/null
+++ b/diffusers/examples/custom_diffusion/README.md
@@ -0,0 +1,280 @@
+# Custom Diffusion training example
+
+[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
+The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+
+```bash
+pip install -r requirements.txt
+pip install clip-retrieval
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+### Cat example 😺
+
+Now let's get our dataset. Download the dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
+
+We also collect 200 real images using `clip-retrieval`, which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target image. The following flags enable the regularization: `with_prior_preservation`, `real_prior`, and `prior_loss_weight=1.`.
+The `class_prompt` should be the same category name as the target image. The collected real images have text captions similar to the `class_prompt`. The retrieved images are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization instead. To collect the real images, run this command first before training.
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
+```
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+export INSTANCE_DIR="./data/cat"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_cat/ \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --class_prompt="cat" --num_class_images=200 \
+ --instance_prompt="photo of a cat" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=250 \
+ --scale_lr --hflip \
+  --modifier_token "<new1>"
+```
+
+**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
+
+To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (which we HIGHLY recommend), follow these steps:
+
+* Install `wandb`: `pip install wandb`.
+* Authorize: `wandb login`.
+* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments:
+ * `num_validation_images`
+ * `validation_steps`
+
+Here is an example command:
+
+```bash
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_cat/ \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --class_prompt="cat" --num_class_images=200 \
+ --instance_prompt="photo of a cat" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=250 \
+ --scale_lr --hflip \
+  --modifier_token "<new1>" \
+  --validation_prompt="<new1> cat sitting in a bucket" \
+ --report_to="wandb"
+```
+
+Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details.
+
+If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat).
+
+### Training on multiple concepts 🐱🪵
+
+Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py).
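+
+For reference, a minimal sketch of what such a concept list could contain is shown below. It uses the same per-concept keys the training script reads (`instance_prompt`, `class_prompt`, `instance_data_dir`, `class_data_dir`); the paths, prompts, and modifier tokens are placeholders to adapt to your own data:
+
+```python
+import json
+
+# Hypothetical two-concept list; adjust paths, prompts, and modifier tokens to your data.
+concepts_list = [
+    {
+        "instance_prompt": "photo of a <new1> cat",
+        "class_prompt": "cat",
+        "instance_data_dir": "./data/cat",
+        "class_data_dir": "./real_reg/samples_cat",
+    },
+    {
+        "instance_prompt": "photo of a <new2> wooden pot",
+        "class_prompt": "wooden pot",
+        "instance_data_dir": "./data/wooden_pot",
+        "class_data_dir": "./real_reg/samples_wooden_pot",
+    },
+]
+
+with open("concept_list.json", "w") as f:
+    json.dump(concepts_list, f, indent=4)
+```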
+
+To collect the real images run this command for each concept in the json file.
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
+```
+
+And then we're ready to start training!
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --output_dir=$OUTPUT_DIR \
+ --concepts_list=./concept_list.json \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --num_class_images=200 \
+ --scale_lr --hflip \
+  --modifier_token "<new1>+<new2>"
+```
+
+Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details.
+
+### Training on human faces
+
+For fine-tuning on human faces, we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn`, with at least 15-20 images.
+
+To collect the real images use this command first before training.
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
+```
+
+Then start training!
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+export INSTANCE_DIR="path-to-images"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_person/ \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --class_prompt="person" --num_class_images=200 \
+ --instance_prompt="photo of a person" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=5e-6 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=1000 \
+ --scale_lr --hflip --noaug \
+ --freeze_model crossattn \
+  --modifier_token "<new1>" \
+ --enable_xformers_memory_efficient_attention
+```
+
+## Inference
+
+Once you have trained a model using the above command, you can run inference with the command below. Make sure to include the modifier token (e.g. `<new1>` in the above example) in your prompt.
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+).to("cuda")
+pipe.unet.load_attn_procs(
+ "path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin"
+)
+pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
+
+image = pipe(
+    "<new1> cat sitting in a bucket",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("cat.png")
+```
+
+It's possible to directly load these parameters from a Hub repository:
+
+```python
+import torch
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline
+
+model_id = "sayakpaul/custom-diffusion-cat"
+card = RepoCard.load(model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
+pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
+pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
+
+image = pipe(
+    "<new1> cat sitting in a bucket",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("cat.png")
+```
+
+Here is an example of performing inference with multiple concepts:
+
+```python
+import torch
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline
+
+model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
+card = RepoCard.load(model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
+pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
+pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
+pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
+
+image = pipe(
+    "the <new1> cat sculpture in the style of a <new2> wooden pot",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("multi-subject.png")
+```
+
+Here, `<new1>` and `<new2>` refer to the multiple concepts (the cat and the wooden pot).
+
+### Inference from a training checkpoint
+
+You can also perform inference from one of the complete checkpoints saved during the training process, if you used the `--checkpointing_steps` argument.
+
+TODO.
+
+## Set grads to none
+To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
+
+More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
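+
+As a rough, standalone illustration of what this flag changes (the training scripts pass it straight through to `optimizer.zero_grad(set_to_none=...)`), consider the following sketch; it is not part of the training code:
+
+```python
+import torch
+
+param = torch.nn.Parameter(torch.randn(2, 2))
+optimizer = torch.optim.AdamW([param], lr=1e-5)
+
+param.sum().backward()
+optimizer.zero_grad(set_to_none=False)
+print(param.grad)  # a zero-filled tensor is kept in memory
+
+param.sum().backward()
+optimizer.zero_grad(set_to_none=True)
+print(param.grad)  # None: the gradient buffer has been released
+```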
+
+## Experimental results
+You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. We also released a more extensive dataset of 101 concepts for evaluating model customization methods. For more details please refer to our [dataset webpage](https://www.cs.cmu.edu/~custom-diffusion/dataset.html).
\ No newline at end of file
diff --git a/diffusers/examples/custom_diffusion/requirements.txt b/diffusers/examples/custom_diffusion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7d93f3d03bd8eba09b8cab5e570d15380456b66a
--- /dev/null
+++ b/diffusers/examples/custom_diffusion/requirements.txt
@@ -0,0 +1,6 @@
+accelerate
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/custom_diffusion/retrieve.py b/diffusers/examples/custom_diffusion/retrieve.py
new file mode 100644
index 0000000000000000000000000000000000000000..a28fe344d93b469f414909cd8c4d61d7ff16e50f
--- /dev/null
+++ b/diffusers/examples/custom_diffusion/retrieve.py
@@ -0,0 +1,87 @@
+# Copyright 2024 Custom Diffusion authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from io import BytesIO
+from pathlib import Path
+
+import requests
+from clip_retrieval.clip_client import ClipClient
+from PIL import Image
+from tqdm import tqdm
+
+
+def retrieve(class_prompt, class_data_dir, num_class_images):
+ factor = 1.5
+ num_images = int(factor * num_class_images)
+ client = ClipClient(
+ url="https://knn.laion.ai/knn-service", indice_name="laion_400m", num_images=num_images, aesthetic_weight=0.1
+ )
+
+ os.makedirs(f"{class_data_dir}/images", exist_ok=True)
+ if len(list(Path(f"{class_data_dir}/images").iterdir())) >= num_class_images:
+ return
+
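+    # Query the index repeatedly, growing the requested number of results by 1.5x each
+    # round, until enough images are returned or the request size exceeds 1e4.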
+ while True:
+ class_images = client.query(text=class_prompt)
+ if len(class_images) >= factor * num_class_images or num_images > 1e4:
+ break
+ else:
+ num_images = int(factor * num_images)
+ client = ClipClient(
+ url="https://knn.laion.ai/knn-service",
+ indice_name="laion_400m",
+ num_images=num_images,
+ aesthetic_weight=0.1,
+ )
+
+ count = 0
+ total = 0
+ pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
+
+ with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
+ f"{class_data_dir}/images.txt", "w"
+ ) as f3:
+ while total < num_class_images:
+ images = class_images[count]
+ count += 1
+ try:
+ img = requests.get(images["url"], timeout=30)
+ if img.status_code == 200:
+ _ = Image.open(BytesIO(img.content))
+ with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f:
+ f.write(img.content)
+ f1.write(images["caption"] + "\n")
+ f2.write(images["url"] + "\n")
+ f3.write(f"{class_data_dir}/images/{total}.jpg" + "\n")
+ total += 1
+ pbar.update(1)
+ else:
+ continue
+ except Exception:
+ continue
+ return
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("", add_help=False)
+ parser.add_argument("--class_prompt", help="text prompt to retrieve images", required=True, type=str)
+ parser.add_argument("--class_data_dir", help="path to save images", required=True, type=str)
+ parser.add_argument("--num_class_images", help="number of images to download", default=200, type=int)
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ retrieve(args.class_prompt, args.class_data_dir, args.num_class_images)
diff --git a/diffusers/examples/custom_diffusion/test_custom_diffusion.py b/diffusers/examples/custom_diffusion/test_custom_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d7e895b3bb838d2a432baae340e3eb2c44f14db
--- /dev/null
+++ b/diffusers/examples/custom_diffusion/test_custom_diffusion.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class CustomDiffusion(ExamplesTestsAccelerate):
+ def test_custom_diffusion(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+                --instance_prompt <new1>
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 1.0e-05
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+                --modifier_token <new1>
+ --no_safe_serialization
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "<new1>.bin")))
+
+ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+                --instance_prompt=<new1>
+ --resolution=64
+ --train_batch_size=1
+                --modifier_token=<new1>
+ --dataloader_num_workers=0
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+                --instance_prompt=<new1>
+ --resolution=64
+ --train_batch_size=1
+                --modifier_token=<new1>
+ --dataloader_num_workers=0
+ --max_train_steps=4
+ --checkpointing_steps=2
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ resume_run_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+                --instance_prompt=<new1>
+ --resolution=64
+ --train_batch_size=1
+                --modifier_token=<new1>
+ --dataloader_num_workers=0
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/diffusers/examples/custom_diffusion/train_custom_diffusion.py b/diffusers/examples/custom_diffusion/train_custom_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e498ca98b1c76abed63adba20e7e6118bb2339f5
--- /dev/null
+++ b/diffusers/examples/custom_diffusion/train_custom_diffusion.py
@@ -0,0 +1,1378 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 Custom Diffusion authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import safetensors
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import HfApi, create_repo
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import (
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionAttnProcessor2_0,
+ CustomDiffusionXFormersAttnProcessor,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def freeze_params(params):
+ for param in params:
+ param.requires_grad = False
+
+
+def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# Custom Diffusion - {repo_id}
+
+These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n
+{img_str}
+
+\nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ prompt=prompt,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "text-to-image",
+ "diffusers",
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "custom-diffusion",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def collate_fn(examples, with_prior_preservation):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+ mask = [example["mask"] for example in examples]
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ mask += [example["class_mask"] for example in examples]
+
+ input_ids = torch.cat(input_ids, dim=0)
+ pixel_values = torch.stack(pixel_values)
+ mask = torch.stack(mask)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ mask = mask.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"input_ids": input_ids, "pixel_values": pixel_values, "mask": mask.unsqueeze(1)}
+ return batch
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+class CustomDiffusionDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ concepts_list,
+ tokenizer,
+ size=512,
+ mask_size=64,
+ center_crop=False,
+ with_prior_preservation=False,
+ num_class_images=200,
+ hflip=False,
+ aug=True,
+ ):
+ self.size = size
+ self.mask_size = mask_size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.interpolation = Image.BILINEAR
+ self.aug = aug
+
+ self.instance_images_path = []
+ self.class_images_path = []
+ self.with_prior_preservation = with_prior_preservation
+ for concept in concepts_list:
+ inst_img_path = [
+ (x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()
+ ]
+ self.instance_images_path.extend(inst_img_path)
+
+ if with_prior_preservation:
+ class_data_root = Path(concept["class_data_dir"])
+ if os.path.isdir(class_data_root):
+ class_images_path = list(class_data_root.iterdir())
+ class_prompt = [concept["class_prompt"] for _ in range(len(class_images_path))]
+ else:
+ with open(class_data_root, "r") as f:
+ class_images_path = f.read().splitlines()
+ with open(concept["class_prompt"], "r") as f:
+ class_prompt = f.read().splitlines()
+
+ class_img_path = list(zip(class_images_path, class_prompt))
+ self.class_images_path.extend(class_img_path[:num_class_images])
+
+ random.shuffle(self.instance_images_path)
+ self.num_instance_images = len(self.instance_images_path)
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
+
+ self.image_transforms = transforms.Compose(
+ [
+ self.flip,
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def preprocess(self, image, scale, resample):
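+        # Resize the image to `scale` x `scale`. If the result is larger than `self.size`,
+        # take a random `self.size` crop of it; otherwise paste it at a random offset into a
+        # blank `self.size` canvas. Also build a mask at `self.size // factor` resolution
+        # (i.e. the latent resolution) marking the region that contains real image content.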
+ outer, inner = self.size, scale
+ factor = self.size // self.mask_size
+ if scale > self.size:
+ outer, inner = scale, self.size
+ top, left = np.random.randint(0, outer - inner + 1), np.random.randint(0, outer - inner + 1)
+ image = image.resize((scale, scale), resample=resample)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+ instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
+ mask = np.zeros((self.size // factor, self.size // factor))
+ if scale > self.size:
+ instance_image = image[top : top + inner, left : left + inner, :]
+ mask = np.ones((self.size // factor, self.size // factor))
+ else:
+ instance_image[top : top + inner, left : left + inner, :] = image
+ mask[
+ top // factor + 1 : (top + scale) // factor - 1, left // factor + 1 : (left + scale) // factor - 1
+ ] = 1.0
+ return instance_image, mask
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images]
+ instance_image = Image.open(instance_image)
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ instance_image = self.flip(instance_image)
+
+ # apply resize augmentation and create a valid image region mask
+ random_scale = self.size
+ if self.aug:
+ random_scale = (
+ np.random.randint(self.size // 3, self.size + 1)
+ if np.random.uniform() < 0.66
+ else np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
+ )
+ instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation)
+
+ if random_scale < 0.6 * self.size:
+ instance_prompt = np.random.choice(["a far away ", "very small "]) + instance_prompt
+ elif random_scale > self.size:
+ instance_prompt = np.random.choice(["zoomed in ", "close up "]) + instance_prompt
+
+ example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1)
+ example["mask"] = torch.from_numpy(mask)
+ example["instance_prompt_ids"] = self.tokenizer(
+ instance_prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ if self.with_prior_preservation:
+ class_image, class_prompt = self.class_images_path[index % self.num_class_images]
+ class_image = Image.open(class_image)
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_mask"] = torch.ones_like(example["mask"])
+ example["class_prompt_ids"] = self.tokenizer(
+ class_prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ return example
+
+
+def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
+ """Saves the new token embeddings from the text encoder."""
+ logger.info("Saving embeddings")
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
+ for x, y in zip(modifier_token_id, args.modifier_token):
+ learned_embeds_dict = {}
+ learned_embeds_dict[y] = learned_embeds[x]
+
+ if safe_serialization:
+ filename = f"{output_dir}/{y}.safetensors"
+ safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
+ else:
+ filename = f"{output_dir}/{y}.bin"
+ torch.save(learned_embeds_dict, filename)
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Custom Diffusion training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=2,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=50,
+ help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument(
+ "--real_prior",
+ default=False,
+ action="store_true",
+ help="real images as prior.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=200,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="custom-diffusion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=250,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-5,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=2,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--freeze_model",
+ type=str,
+ default="crossattn_kv",
+ choices=["crossattn_kv", "crossattn"],
+        help='Set to "crossattn" to fine-tune all parameters in the cross-attention layers; "crossattn_kv" trains only the key and value projections.',
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument(
+ "--concepts_list",
+ type=str,
+ default=None,
+ help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--modifier_token",
+ type=str,
+ default=None,
+ help="A token to use as a modifier for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default="ktn+pll+ucd", help="A token to use as initializer word."
+ )
+ parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
+ parser.add_argument(
+ "--noaug",
+ action="store_true",
+        help="Don't apply augmentation when this flag is enabled.",
+ )
+ parser.add_argument(
+ "--no_safe_serialization",
+ action="store_true",
+ help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.concepts_list is None:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("custom-diffusion", config=vars(args))
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+ if args.concepts_list is None:
+ args.concepts_list = [
+ {
+ "instance_prompt": args.instance_prompt,
+ "class_prompt": args.class_prompt,
+ "instance_data_dir": args.instance_data_dir,
+ "class_data_dir": args.class_data_dir,
+ }
+ ]
+ else:
+ with open(args.concepts_list, "r") as f:
+ args.concepts_list = json.load(f)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ for i, concept in enumerate(args.concepts_list):
+ class_images_dir = Path(concept["class_data_dir"])
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True, exist_ok=True)
+ if args.real_prior:
+ assert (
+ class_images_dir / "images"
+ ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (
+ len(list((class_images_dir / "images").iterdir())) == args.num_class_images
+ ), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (
+ class_images_dir / "caption.txt"
+ ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (
+ class_images_dir / "images.txt"
+ ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
+ concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
+ args.concepts_list[i] = concept
+ accelerator.wait_for_everyone()
+ else:
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader,
+ desc="Generating class images",
+ disable=not accelerator.is_local_main_process,
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = (
+ class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ )
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer_name,
+ revision=args.revision,
+ use_fast=False,
+ )
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # Adding a modifier token which is optimized ####
+ # Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
+ modifier_token_id = []
+ initializer_token_id = []
+ if args.modifier_token is not None:
+ args.modifier_token = args.modifier_token.split("+")
+ args.initializer_token = args.initializer_token.split("+")
+ if len(args.modifier_token) > len(args.initializer_token):
+ raise ValueError("You must specify + separated initializer token for each modifier token.")
+ for modifier_token, initializer_token in zip(
+ args.modifier_token, args.initializer_token[: len(args.modifier_token)]
+ ):
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(modifier_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {modifier_token}. Please pass a different"
+ " `modifier_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode([initializer_token], add_special_tokens=False)
+ print(token_ids)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id.append(token_ids[0])
+ modifier_token_id.append(tokenizer.convert_tokens_to_ids(modifier_token))
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ for x, y in zip(modifier_token_id, initializer_token_id):
+ token_embeds[x] = token_embeds[y]
+
+ # Freeze all parameters except for the token embeddings in text encoder
+ params_to_freeze = itertools.chain(
+ text_encoder.text_model.encoder.parameters(),
+ text_encoder.text_model.final_layer_norm.parameters(),
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
+ )
+ freeze_params(params_to_freeze)
+ ########################################################
+ ########################################################
+
+ vae.requires_grad_(False)
+ if args.modifier_token is None:
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ if accelerator.mixed_precision != "fp16" and args.modifier_token is not None:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ attention_class = (
+ CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
+ )
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ attention_class = CustomDiffusionXFormersAttnProcessor
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # now we will add new Custom Diffusion weights to the attention layers
+ # It's important to realize here how many attention weights will be added and of which sizes
+ # The sizes of the attention layers consist only of two different variables:
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
+
+ # Let's first see how many attention processors we will have to set.
+ # For Stable Diffusion, it should be equal to:
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+    # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
+ # => 32 layers
+
+ # Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer
+ train_kv = True
+ train_q_out = False if args.freeze_model == "crossattn_kv" else True
+ custom_diffusion_attn_procs = {}
+
+ st = unet.state_dict()
+ for name, _ in unet.attn_processors.items():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ layer_name = name.split(".processor")[0]
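+        # Initialize the new Custom Diffusion projections from the pretrained attention
+        # weights: to_k/to_v always, plus to_q/to_out when train_q_out is enabled.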
+ weights = {
+ "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
+ "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
+ }
+ if train_q_out:
+ weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
+ weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
+ weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
+ if cross_attention_dim is not None:
+ custom_diffusion_attn_procs[name] = attention_class(
+ train_kv=train_kv,
+ train_q_out=train_q_out,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ ).to(unet.device)
+ custom_diffusion_attn_procs[name].load_state_dict(weights)
+ else:
+ custom_diffusion_attn_procs[name] = attention_class(
+ train_kv=False,
+ train_q_out=False,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ )
+ del st
+ unet.set_attn_processor(custom_diffusion_attn_procs)
+ custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)
+
+ accelerator.register_for_checkpointing(custom_diffusion_layers)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.modifier_token is not None:
+ text_encoder.gradient_checkpointing_enable()
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+ if args.with_prior_preservation:
+ args.learning_rate = args.learning_rate * 2.0
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ optimizer = optimizer_class(
+ itertools.chain(text_encoder.get_input_embeddings().parameters(), custom_diffusion_layers.parameters())
+ if args.modifier_token is not None
+ else custom_diffusion_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = CustomDiffusionDataset(
+ concepts_list=args.concepts_list,
+ tokenizer=tokenizer,
+ with_prior_preservation=args.with_prior_preservation,
+ size=args.resolution,
+ mask_size=vae.encode(
+ torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device)
+ )
+ .latent_dist.sample()
+ .size()[-1],
+ center_crop=args.center_crop,
+ num_class_images=args.num_class_images,
+ hflip=args.hflip,
+ aug=not args.noaug,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.modifier_token is not None:
+ custom_diffusion_layers, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ custom_diffusion_layers, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.modifier_token is not None:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+ mask = torch.chunk(batch["mask"], 2, dim=0)[0]
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ mask = batch["mask"]
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
+ accelerator.backward(loss)
+ # Zero out the gradients for all token embeddings except the newly added
+ # embeddings for the concept, as we only want to optimize the concept embeddings
+ if args.modifier_token is not None:
+ if accelerator.num_processes > 1:
+ grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad
+ else:
+ grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
+ # Get the index for tokens that we want to zero the grads for
+ index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
+ for i in range(1, len(modifier_token_id)):
+ index_grads_to_zero = index_grads_to_zero & (
+ torch.arange(len(tokenizer)) != modifier_token_id[i]
+ )
+ grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[
+ index_grads_to_zero, :
+ ].fill_(0)
+
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(text_encoder.parameters(), custom_diffusion_layers.parameters())
+ if args.modifier_token is not None
+ else custom_diffusion_layers.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ images = []
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[
+ 0
+ ]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the custom diffusion layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+ unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
+ save_new_embed(
+ text_encoder,
+ modifier_token_id,
+ accelerator,
+ args,
+ args.output_dir,
+ safe_serialization=not args.no_safe_serialization,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ weight_name = (
+ "pytorch_custom_diffusion_weights.safetensors"
+ if not args.no_safe_serialization
+ else "pytorch_custom_diffusion_weights.bin"
+ )
+ pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name)
+ for token in args.modifier_token:
+ token_weight_name = f"{token}.safetensors" if not args.no_safe_serialization else f"{token}.bin"
+ pipeline.load_textual_inversion(args.output_dir, weight_name=token_weight_name)
+
+ # run inference
+ if args.validation_prompt and args.num_validation_images > 0:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ )
+ api = HfApi(token=args.hub_token)
+ api.upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/dreambooth/README.md b/diffusers/examples/dreambooth/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f97a4d0cd0f4fb31b09e0a1e537ab532ee5d2837
--- /dev/null
+++ b/diffusers/examples/dreambooth/README.md
@@ -0,0 +1,744 @@
+# DreamBooth training example
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
+The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+And launch the training using:
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400 \
+ --push_to_hub
+```
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
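+
+With `--with_prior_preservation`, the objective becomes the instance loss plus the class (prior) loss scaled by `--prior_loss_weight`. A minimal sketch with dummy tensors (illustration only, not the training script's exact code):
+
+```python
+import torch
+import torch.nn.functional as F
+
+# Dummy predictions/targets standing in for the instance and class halves of a batch.
+model_pred, target = torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64)
+model_pred_prior, target_prior = torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64)
+prior_loss_weight = 1.0  # --prior_loss_weight
+
+instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+loss = instance_loss + prior_loss_weight * prior_loss
+```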
+
+
+### Training on a 16GB GPU:
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.
+
+To install `bitsandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
+
+
+### Training on a 12GB GPU:
+
+It is possible to run dreambooth on a 12GB GPU by using the following optimizations:
+- [gradient checkpointing and the 8-bit optimizer](#training-on-a-16gb-gpu)
+- [xformers](#training-with-xformers)
+- [setting grads to none](#set-grads-to-none)
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --use_8bit_adam \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
+
+
+### Training on an 8GB GPU:
+
+By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
+tensors from VRAM to either CPU or NVMe, allowing training with less VRAM.
+
+DeepSpeed needs to be enabled with `accelerate config`. During configuration,
+answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
+mixed precision, and offloading of both parameters and optimizer state to the CPU, it's
+possible to train with under 8GB of VRAM, at the cost of requiring significantly
+more system RAM (about 25GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
+
+Changing the default Adam optimizer to DeepSpeed's special version of Adam,
+`deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup, but enabling
+it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer
+does not seem to be compatible with DeepSpeed at the moment.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch --mixed_precision="fp16" train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --sample_batch_size=1 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
+
+### Fine-tune text encoder with the UNet.
+
+The script also allows you to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
+Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
+
+___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB of VRAM.___
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
+
+### Using DreamBooth for pipelines other than Stable Diffusion
+
+The [AltDiffusion pipeline](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion) also supports dreambooth fine-tuning. The process is the same as above, all you need to do is replace the `MODEL_NAME` like this:
+
+```
+export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
+or
+export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
+```
+
+### Inference
+
+Once you have trained a model using the above command, you can run inference simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
+```
+
+### Inference from a training checkpoint
+
+You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.
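+
+As a rough sketch, assuming the checkpoint directory contains a `unet` subfolder (as saved by recent versions of the script) and using placeholder paths:
+
+```python
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+import torch
+
+# Hypothetical checkpoint path; adjust to your own --output_dir and step.
+unet = UNet2DConditionModel.from_pretrained(
+    "path-to-save-model/checkpoint-400/unet", torch_dtype=torch.float16
+)
+
+pipe = DiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", unet=unet, torch_dtype=torch.float16
+).to("cuda")
+image = pipe("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("dog-bucket.png")
+```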
+
+## Training with Low-Rank Adaptation of Large Language Models (LoRA)
+
+Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow you to control the extent to which the model is adapted towards the new training images via a `scale` parameter.
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
+the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+### Training
+
+Let's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example).
+
+First, you need to set up your DreamBooth training example as explained in the [installation section](#Installing-the-dependencies).
+Next, let's download the dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of your directory further below. This will be our training data.
+
+Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5).
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution to easily see the generated images during training. All you need to do is run `pip install wandb` before training and pass `--report_to="wandb"` to automatically log images.___**
+
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="path-to-save-model"
+```
+
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
+we need to be logged in and add the `--push_to_hub` flag.
+
+```bash
+huggingface-cli login
+```
+
+Now we can start training!
+
+```bash
+accelerate launch train_dreambooth_lora.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --checkpointing_steps=100 \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=50 \
+ --seed="0" \
+ --push_to_hub
+```
+
+**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we
+use *1e-4* instead of the usual *2e-6*.___**
+
+The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dreambooth_dog_example](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example). **___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**
+
+The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
+You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
+
+Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested in knowing more about how we
+enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).
+
+With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
+
+
+### Inference
+
+After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
+load the original pipeline:
+
+```python
+from diffusers import DiffusionPipeline
+pipe = DiffusionPipeline.from_pretrained("base-model-name").to("cuda")
+```
+
+Next, we can load the adapter layers into the pipeline with the [`load_lora_weights` function](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#lora).
+
+```python
+pipe.load_lora_weights("path-to-the-lora-checkpoint")
+```
+
+Finally, we can run the model for inference.
+
+```python
+image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
+```
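+
+The `scale` parameter mentioned earlier can also be adjusted at inference time to control how strongly the LoRA weights influence the output (a sketch; `0.5` is an arbitrary value, `1.0` means the LoRA is fully applied):
+
+```python
+# Blend 50% of the LoRA adaptation into the base model's attention layers.
+image = pipe(
+    "A picture of a sks dog in a bucket",
+    num_inference_steps=25,
+    cross_attention_kwargs={"scale": 0.5},
+).images[0]
+```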
+
+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+If you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA
+weights. For example:
+
+```python
+from huggingface_hub.repocard import RepoCard
+from diffusers import StableDiffusionPipeline
+import torch
+
+lora_model_id = "sayakpaul/dreambooth-text-encoder-test"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.load_lora_weights(lora_model_id)
+image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
+```
+
+Note that the use of [`StableDiffusionLoraLoaderMixin.load_lora_weights`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.StableDiffusionLoraLoaderMixin.load_lora_weights) is preferred to [`UNet2DConditionLoadersMixin.load_attn_procs`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs) for loading LoRA parameters. This is because
+`StableDiffusionLoraLoaderMixin.load_lora_weights` can handle the following situations:
+
+* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
+
+ ```py
+ pipe.load_lora_weights(lora_model_path)
+ ```
+
+* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+___Note: The flax example doesn't yet support features like gradient checkpointing and gradient accumulation, so to use flax for faster training we will need >30GB cards.___
+
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+
+### Training without prior preservation loss
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --max_train_steps=400
+```
+
+
+### Training with prior preservation loss
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+
+### Fine-tune text encoder with the UNet.
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=2e-6 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
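+
+xFormers can usually be installed with pip (a sketch; see the linked instructions for a wheel matching your PyTorch/CUDA version):
+
+```bash
+pip install xformers
+```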
+
+You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
+
+### Set grads to none
+
+To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
+
+More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
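+
+This maps to PyTorch's `Optimizer.zero_grad(set_to_none=True)`; a minimal standalone sketch of the behavior:
+
+```python
+import torch
+
+model = torch.nn.Linear(4, 4)
+optimizer = torch.optim.AdamW(model.parameters())
+
+loss = model(torch.randn(2, 4)).sum()
+loss.backward()
+optimizer.step()
+
+# With set_to_none=True the .grad attributes become None instead of zero tensors,
+# which skips a memset and releases the gradient memory until the next backward pass.
+optimizer.zero_grad(set_to_none=True)
+assert all(p.grad is None for p in model.parameters())
+```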
+
+### Experimental results
+You can refer to [this blog post](https://huggingface.co/blog/dreambooth) that discusses some of DreamBooth experiments in detail. Specifically, it recommends a set of DreamBooth-specific tips and tricks that we have found to work well for a variety of subjects.
+
+## IF
+
+You can use the LoRA and full DreamBooth scripts to train the text-to-image [IF model](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) and the stage II upscaler
+[IF model](https://huggingface.co/DeepFloyd/IF-II-L-v1.0).
+
+Note that IF has a predicted variance, and our finetuning scripts only train the model's predicted error, so for finetuned IF models we switch to a fixed
+variance schedule. The full finetuning scripts will update the scheduler config for the full saved model. However, when loading saved LoRA weights, you
+must also update the pipeline's scheduler config.
+
+```py
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0")
+
+pipe.load_lora_weights("")  # pass the path to the saved LoRA weights (e.g. the training --output_dir)
+
+# Update scheduler config to fixed variance schedule
+pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small")
+```
+
+Additionally, a few alternative CLI flags are needed for IF.
+
+`--resolution=64`: IF is a pixel-space diffusion model. In order to operate on uncompressed pixels, the input images are of a much smaller resolution.
+
+`--pre_compute_text_embeddings`: IF uses [T5](https://huggingface.co/docs/transformers/model_doc/t5) for its text encoder. In order to save GPU memory, we pre-compute all text embeddings and then deallocate
+T5.
+
+`--tokenizer_max_length=77`: T5 has a longer default text length, but the default IF encoding procedure uses a smaller number.
+
+`--text_encoder_use_attention_mask`: T5 passes the attention mask to the text encoder.
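+
+As a rough standalone illustration of what these three flags amount to (not the script's exact code; the IF repository is gated, so the usual `huggingface-cli login` applies, and T5-XXL is large):
+
+```py
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+# Load only the tokenizer and T5 text encoder from the IF stage I repository.
+tokenizer = AutoTokenizer.from_pretrained("DeepFloyd/IF-I-XL-v1.0", subfolder="tokenizer")
+text_encoder = T5EncoderModel.from_pretrained("DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder")
+
+inputs = tokenizer(
+    "a sks dog", max_length=77, padding="max_length", truncation=True, return_tensors="pt"
+)
+with torch.no_grad():
+    prompt_embeds = text_encoder(inputs.input_ids, attention_mask=inputs.attention_mask)[0]
+
+# Once the embeddings are cached, T5 is no longer needed and can be deleted to free memory.
+del text_encoder
+```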
+
+### Tips and Tricks
+We find LoRA to be sufficient for finetuning the stage I model as the low resolution of the model makes representing finegrained detail hard regardless.
+
+For common and/or not visually complex object concepts, you can get away with not fine-tuning the upscaler. Just be sure to adjust the prompt passed to the
+upscaler to remove the new token from the instance prompt. For example, if your stage I prompt is "a sks dog", use "a dog" for your stage II prompt.
+
+For fine-grained detail like faces that isn't present in the original training set, we find that full finetuning of the stage II upscaler is better than
+LoRA finetuning of stage II.
+
+For fine-grained detail like faces, we find that lower learning rates along with larger batch sizes work best.
+
+For stage II, we find that lower learning rates are also needed.
+
+We found experimentally that the DDPM scheduler, with its default larger number of denoising steps, sometimes works better than the DPM Solver scheduler
+used in the training scripts.
+
+### Stage II additional validation images
+
+The stage II validation requires images to upscale, so we can download a downsized version of the training set:
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog_downsized"
+snapshot_download(
+ "diffusers/dog-example-downsized",
+ local_dir=local_dir,
+ repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+### IF stage I LoRA Dreambooth
+This training configuration requires ~28 GB VRAM.
+
+```sh
+export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_dog_lora"
+
+accelerate launch train_dreambooth_lora.py \
+ --report_to wandb \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a sks dog" \
+ --resolution=64 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --scale_lr \
+ --max_train_steps=1200 \
+ --validation_prompt="a sks dog" \
+ --validation_epochs=25 \
+ --checkpointing_steps=100 \
+ --pre_compute_text_embeddings \
+ --tokenizer_max_length=77 \
+ --text_encoder_use_attention_mask
+```
+
+### IF stage II LoRA Dreambooth
+
+`--validation_images`: These images are upscaled during validation steps.
+
+`--class_labels_conditioning=timesteps`: Pass additional conditioning to the UNet needed for stage II.
+
+`--learning_rate=1e-6`: Lower learning rate than stage I.
+
+`--resolution=256`: The upscaler expects higher resolution inputs.
+
+```sh
+export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_dog_upscale"
+export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"
+
+python train_dreambooth_lora.py \
+ --report_to wandb \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a sks dog" \
+ --resolution=256 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-6 \
+ --max_train_steps=2000 \
+ --validation_prompt="a sks dog" \
+ --validation_epochs=100 \
+ --checkpointing_steps=500 \
+ --pre_compute_text_embeddings \
+ --tokenizer_max_length=77 \
+ --text_encoder_use_attention_mask \
+ --validation_images $VALIDATION_IMAGES \
+ --class_labels_conditioning=timesteps
+```
+
+### IF Stage I Full Dreambooth
+`--skip_save_text_encoder`: When training the full model, this will skip saving the entire T5 with the finetuned model. You can still load the pipeline
+with a T5 loaded from the original model.
+
+`--use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8-bit Adam.
+
+`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is
+likely the learning rate can be increased with larger batch sizes.
+
+Using 8-bit Adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
+
+`--validation_scheduler`: Set a particular scheduler via a string. We found that it is better to use the DDPMScheduler for validation when training DeepFloyd IF.
+
+```sh
+export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
+
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_if"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=64 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-7 \
+ --max_train_steps=150 \
+ --validation_prompt "a photo of sks dog" \
+ --validation_steps 25 \
+ --text_encoder_use_attention_mask \
+ --tokenizer_max_length 77 \
+ --pre_compute_text_embeddings \
+ --use_8bit_adam \
+ --set_grads_to_none \
+ --skip_save_text_encoder \
+ --validation_scheduler DDPMScheduler \
+ --push_to_hub
+```
+
+### IF Stage II Full Dreambooth
+
+`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as
+1e-8.
+
+`--resolution=256`: The upscaler expects higher resolution inputs.
+
+`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II, particularly with
+faces, required large effective batch sizes.
+
+```sh
+export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_dog_upscale"
+export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"
+
+accelerate launch train_dreambooth.py \
+ --report_to wandb \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a sks dog" \
+ --resolution=256 \
+ --train_batch_size=2 \
+ --gradient_accumulation_steps=6 \
+ --learning_rate=5e-6 \
+ --max_train_steps=2000 \
+ --validation_prompt="a sks dog" \
+ --validation_steps=150 \
+ --checkpointing_steps=500 \
+ --pre_compute_text_embeddings \
+ --tokenizer_max_length=77 \
+ --text_encoder_use_attention_mask \
+ --validation_images $VALIDATION_IMAGES \
+ --class_labels_conditioning timesteps \
+ --validation_scheduler DDPMScheduler \
+ --push_to_hub
+```
+
+## Stable Diffusion XL
+
+We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
diff --git a/diffusers/examples/dreambooth/README_flux.md b/diffusers/examples/dreambooth/README_flux.md
new file mode 100644
index 0000000000000000000000000000000000000000..69dfd241395b8779ff037ea7d7da4902a9560962
--- /dev/null
+++ b/diffusers/examples/dreambooth/README_flux.md
@@ -0,0 +1,232 @@
+# DreamBooth training example for FLUX.1 [dev]
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
+
+The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script.
+> [!NOTE]
+> **Memory consumption**
+>
+> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
+> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
+
+> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
+> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
+> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training)
+
+> [!NOTE]
+> **Gated model**
+>
+> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+
+```bash
+huggingface-cli login
+```
+
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd in the `examples/dreambooth` folder and run
+```bash
+pip install -r requirements_flux.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
+
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
+
+Now, we can launch training using:
+
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.1-dev"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-flux"
+
+accelerate launch train_dreambooth_flux.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="bf16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --gradient_accumulation_steps=4 \
+ --optimizer="prodigy" \
+ --learning_rate=1. \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+> [!NOTE]
+> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
+
+## LoRA + DreamBooth
+
+[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
+
+Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
+
+### Prodigy Optimizer
+Prodigy is an adaptive optimizer that dynamically adjusts the learning rate of the learned parameters based on past gradients, allowing for more efficient convergence.
+By using Prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
+
+To use Prodigy, specify
+```bash
+--optimizer="prodigy"
+```
+> [!TIP]
+> When using Prodigy it's generally good practice to set `--learning_rate=1.0`.
+
+To perform DreamBooth with LoRA, run:
+
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.1-dev"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-flux-lora"
+
+accelerate launch train_dreambooth_lora_flux.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="bf16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --gradient_accumulation_steps=4 \
+ --optimizer="prodigy" \
+ --learning_rate=1. \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+### Text Encoder Training
+
+Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
+To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
+
+> [!NOTE]
+> This is still an experimental feature.
+> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).
+> By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
+> At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
+
+To perform DreamBooth LoRA with text-encoder training, run:
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.1-dev"
+export OUTPUT_DIR="trained-flux-dev-dreambooth-lora"
+
+accelerate launch train_dreambooth_lora_flux.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="bf16" \
+ --train_text_encoder\
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --gradient_accumulation_steps=4 \
+ --optimizer="prodigy" \
+ --learning_rate=1. \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --seed="0" \
+ --push_to_hub
+```
+
+## Memory Optimizations
+As mentioned, Flux DreamBooth LoRA training is very memory intensive. Here are some options (some still experimental) for more memory-efficient training.
+### Image Resolution
+An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution of the input images; all the images in the train/validation dataset are resized to it.
+Note that by default images are resized to a resolution of 512, which is good to keep in mind in case you're accustomed to training on higher resolutions.
+### Gradient Checkpointing and Accumulation
+* `--gradient_accumulation_steps` refers to the number of update steps to accumulate before performing a backward/update pass.
+By passing a value > 1 you can accumulate gradients over several smaller batches, reaching a larger effective batch size without the corresponding increase in memory.
+* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
+Instead, only a subset of these activations (the checkpoints) is stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass (see the sketch below).
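+
+For example, on a single GPU the command below (a sketch based on the LoRA launch command shown earlier; `--gradient_checkpointing` is assumed to be supported, as in other diffusers training scripts) keeps the per-step batch size at 1 while training with an effective batch size of 4:
+
+```bash
+accelerate launch train_dreambooth_lora_flux.py \
+  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
+  --instance_data_dir="dog" \
+  --output_dir="trained-flux-lora" \
+  --mixed_precision="bf16" \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=4 \
+  --gradient_checkpointing \
+  --optimizer="prodigy" \
+  --learning_rate=1. \
+  --max_train_steps=500
+```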
+### 8-bit-Adam Optimizer
+When training with `AdamW` (doesn't apply to `prodigy`) you can pass `--use_8bit_adam` to reduce the memory requirements of training.
+Make sure to install `bitsandbytes` if you want to do so.
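+
+For example, bitsandbytes can be installed from PyPI:
+
+```bash
+pip install bitsandbytes
+```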
+### Latent caching
+When training without validation runs, we can pre-encode the training images with the VAE, and then delete it to free up some memory.
+To enable latent caching, simply pass `--cache_latents`.
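+
+A rough standalone sketch of the idea (illustration only, not the script's exact code; the FLUX.1 [dev] repository is gated, so the usual login applies):
+
+```python
+import torch
+from diffusers import AutoencoderKL
+
+# Load only the VAE from the Flux repository.
+vae = AutoencoderKL.from_pretrained(
+    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Stand-in batch of training images, already normalized to [-1, 1].
+images = torch.rand(4, 3, 512, 512, device="cuda", dtype=torch.bfloat16) * 2 - 1
+
+with torch.no_grad():
+    cached_latents = vae.encode(images).latent_dist.sample()
+
+# Once every training image has been encoded, the VAE can be dropped to free memory.
+del vae
+torch.cuda.empty_cache()
+```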
+### Precision of saved LoRA layers
+By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers will be saved in `torch.bfloat16` as well.
+This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
+
+## Other notes
+Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
\ No newline at end of file
diff --git a/diffusers/examples/dreambooth/README_sd3.md b/diffusers/examples/dreambooth/README_sd3.md
new file mode 100644
index 0000000000000000000000000000000000000000..6f41c395629ac96307807448f02ab8e85a66ea5a
--- /dev/null
+++ b/diffusers/examples/dreambooth/README_sd3.md
@@ -0,0 +1,188 @@
+# DreamBooth training example for Stable Diffusion 3 (SD3)
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
+
+The `train_dreambooth_sd3.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206). We also provide a LoRA implementation in the `train_dreambooth_lora_sd3.py` script.
+
+> [!NOTE]
+> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+
+```bash
+huggingface-cli login
+```
+
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/dreambooth` folder and run:
+```bash
+pip install -r requirements_sd3.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
+
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
+
+Now, we can launch training using:
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-sd3"
+
+accelerate launch train_dreambooth_sd3.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="fp16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+> [!NOTE]
+> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
+
+> [!TIP]
+> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
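+
+For example, here is a sketch of the launch command above with a longer T5 sequence length and the 8-bit optimizer enabled (the `--max_sequence_length` value is just an illustration, and `--use_8bit_adam` requires `bitsandbytes`):
+
+```bash
+accelerate launch train_dreambooth_sd3.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="fp16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=1e-4 \
+ --max_sequence_length=256 \
+ --use_8bit_adam \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```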
+
+## LoRA + DreamBooth
+
+[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning-like performance with only a fraction of the learnable parameters.
+
+Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
+
+To perform DreamBooth with LoRA, run:
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-sd3-lora"
+
+accelerate launch train_dreambooth_lora_sd3.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="fp16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=1e-5 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+### Text Encoder Training
+Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
+To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
+
+> [!NOTE]
+> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).
+> By enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported and its weights remain frozen when text encoder training is enabled.
+
+To perform DreamBooth LoRA with text-encoder training, run:
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
+export OUTPUT_DIR="trained-sd3-lora"
+
+accelerate launch train_dreambooth_lora_sd3.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name="Norod78/Yarn-art-style" \
+ --instance_prompt="a photo of TOK yarn art dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --train_text_encoder \
+ --gradient_accumulation_steps=1 \
+ --optimizer="prodigy" \
+ --learning_rate=1.0 \
+ --text_encoder_lr=1.0 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=1500 \
+ --rank=32 \
+ --seed="0" \
+ --push_to_hub
+```
+
+## Other notes
+
+1. We default to the "logit_normal" weighting scheme for the loss, following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported by the training script, training may incur numerical instabilities.
+2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917).
+3. Additionally, we now have the option to control whether to apply preconditioning to the model outputs via the `--precondition_outputs` CLI arg, which also affects how the model `target` is calculated.
\ No newline at end of file
diff --git a/diffusers/examples/dreambooth/README_sdxl.md b/diffusers/examples/dreambooth/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..7a42bf8fffd8bf1db76f6db76302480be6dacf8c
--- /dev/null
+++ b/diffusers/examples/dreambooth/README_sdxl.md
@@ -0,0 +1,275 @@
+# DreamBooth training example for Stable Diffusion XL (SDXL)
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
+
+The `train_dreambooth_lora_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
+
+> 💡 **Note**: For now, we only allow DreamBooth fine-tuning of the SDXL UNet via LoRA. LoRA is a parameter-efficient fine-tuning technique introduced in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/dreambooth` folder and run:
+```bash
+pip install -r requirements_sdxl.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
+
+Now, we can launch training using:
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="lora-trained-xl"
+export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
+
+accelerate launch train_dreambooth_lora_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --pretrained_vae_model_name_or_path=$VAE_PATH \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="fp16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+### Dog toy example with < 16GB VRAM
+
+By making use of [`gradient_checkpointing`](https://pytorch.org/docs/stable/checkpoint.html) (which is natively supported in Diffusers), [`xformers`](https://github.com/facebookresearch/xformers), and [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) libraries, you can train SDXL LoRAs with less than 16GB of VRAM by adding the following flags to your accelerate launch command:
+
+```diff
++ --enable_xformers_memory_efficient_attention \
++ --gradient_checkpointing \
++ --use_8bit_adam \
++ --mixed_precision="fp16" \
+```
+
+and making sure that you have the following libraries installed:
+
+```
+bitsandbytes>=0.40.0
+xformers>=0.0.20
+```
+
+### Inference
+
+Once training is done, we can perform inference like so:
+
+```python
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline
+import torch
+
+lora_model_id = <"lora-sdxl-dreambooth-id">
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.load_lora_weights(lora_model_id)
+image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
+image.save("sks_dog.png")
+```
+
+We can further refine the outputs with the [Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0):
+
+```python
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline
+import torch
+
+lora_model_id = <"lora-sdxl-dreambooth-id">
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+# Load the base pipeline and load the LoRA parameters into it.
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.load_lora_weights(lora_model_id)
+
+# Load the refiner.
+refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+)
+refiner.to("cuda")
+
+prompt = "A picture of a sks dog in a bucket"
+generator = torch.Generator("cuda").manual_seed(0)
+
+# Run inference.
+image = pipe(prompt=prompt, output_type="latent", generator=generator).images[0]
+image = refiner(prompt=prompt, image=image[None, :], generator=generator).images[0]
+image.save("refined_sks_dog.png")
+```
+
+Here's a side-by-side comparison of the pipeline outputs with and without the Refiner:
+
+| Without Refiner | With Refiner |
+|---|---|
+|  |  |
+
+### Training with text encoder(s)
+
+Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
+
+* SDXL has two text encoders. So, we fine-tune both using LoRA.
+* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
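+
+For example, here is a sketch of the earlier launch command with text-encoder training enabled (it assumes the same environment variables as above):
+
+```bash
+accelerate launch train_dreambooth_lora_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --pretrained_vae_model_name_or_path=$VAE_PATH \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="fp16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=1e-4 \
+ --train_text_encoder \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```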
+
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
+
+## Notes
+
+In our experiments, we found that SDXL yields good initial results without extensive hyperparameter tuning. For example, without fine-tuning the text encoders and without using prior-preservation, we observed decent results. We didn't explore further hyper-parameter tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗
+
+## Results
+
+You can explore the results from a couple of our internal experiments by checking out this link: [https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl](https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl). Specifically, we used the same script with the exact same hyperparameters on the following datasets:
+
+* [Dogs](https://huggingface.co/datasets/diffusers/dog-example)
+* [Starbucks logo](https://huggingface.co/datasets/diffusers/starbucks-example)
+* [Mr. Potato Head](https://huggingface.co/datasets/diffusers/potato-head-example)
+* [Keramer face](https://huggingface.co/datasets/diffusers/keramer-face-example)
+
+## Running on a free-tier Colab Notebook
+
+Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).
+
+## Conducting EDM-style training
+
+It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
+
+For the SDXL model, simply set:
+
+```diff
++ --do_edm_style_training \
+```
+
+Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
+
+```bash
+accelerate launch train_dreambooth_lora_sdxl.py \
+ --pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
+ --instance_data_dir="dog" \
+ --output_dir="dog-playground-lora" \
+ --mixed_precision="fp16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=1e-4 \
+ --use_8bit_adam \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+> [!CAUTION]
+> Min-SNR gamma is not supported with EDM-style training yet. When training with the PlaygroundAI model, it's recommended not to pass any `variant`.
+
+### DoRA training
+The script now supports DoRA training too!
+> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
+**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction**, and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
+The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.
+
+> [!NOTE]
+> 💡 DoRA training is still _experimental_
+> and is likely to require different hyperparameter values to perform best compared to a LoRA.
+> Specifically, we've noticed 2 differences to take into account in your training:
+> 1. **LoRA seems to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may work well for a DoRA)
+> 2. **DoRA quality is superior to LoRA, especially at lower ranks:** the difference in quality between a rank-8 DoRA and a rank-8 LoRA appears to be more significant than at ranks of 32 or 64, for example.
+> This is also aligned with some of the quantitative analysis shown in the paper.
+
+**Usage**
+1. To use DoRA you need to upgrade the installation of `peft`:
+```bash
+pip install -U peft
+```
+2. Enable DoRA training by adding this flag:
+```bash
+--use_dora
+```
+**Inference**
+Inference is the same as with a regular LoRA 🤗
+
+## Format compatibility
+
+You can pass `--output_kohya_format` to additionally generate a state dictionary which should be compatible with other platforms and tools such as Automatic1111, Comfy, Kohya, etc. The `output_dir` will contain a file named `pytorch_lora_weights_kohya.safetensors`.
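+
+For instance, after a run launched with `--output_kohya_format` using the `OUTPUT_DIR` from the commands above, you would expect to find both weight files in the output directory (illustrative sketch):
+
+```bash
+# List the output directory; both state dicts should be present among the other artifacts.
+ls lora-trained-xl/
+# pytorch_lora_weights.safetensors  pytorch_lora_weights_kohya.safetensors
+```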
\ No newline at end of file
diff --git a/diffusers/examples/dreambooth/requirements.txt b/diffusers/examples/dreambooth/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3f86855e1d1e03c20874e9de4a6037a119509444
--- /dev/null
+++ b/diffusers/examples/dreambooth/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+peft==0.7.0
\ No newline at end of file
diff --git a/diffusers/examples/dreambooth/requirements_flax.txt b/diffusers/examples/dreambooth/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8f85ad523a3b46b65abf0138c05ecdd656e6845c
--- /dev/null
+++ b/diffusers/examples/dreambooth/requirements_flax.txt
@@ -0,0 +1,8 @@
+transformers>=4.25.1
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/dreambooth/requirements_flux.txt b/diffusers/examples/dreambooth/requirements_flux.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dbc124ff6526c07b0357c42f3635e75cf484105f
--- /dev/null
+++ b/diffusers/examples/dreambooth/requirements_flux.txt
@@ -0,0 +1,8 @@
+accelerate>=0.31.0
+torchvision
+transformers>=4.41.2
+ftfy
+tensorboard
+Jinja2
+peft>=0.11.1
+sentencepiece
\ No newline at end of file
diff --git a/diffusers/examples/dreambooth/requirements_sd3.txt b/diffusers/examples/dreambooth/requirements_sd3.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1c1ad1f3a7380ba5ed685d8ce0279939968dffdd
--- /dev/null
+++ b/diffusers/examples/dreambooth/requirements_sd3.txt
@@ -0,0 +1,8 @@
+accelerate>=0.31.0
+torchvision
+transformers>=4.41.2
+ftfy
+tensorboard
+Jinja2
+peft==0.11.1
+sentencepiece
\ No newline at end of file
diff --git a/diffusers/examples/dreambooth/requirements_sdxl.txt b/diffusers/examples/dreambooth/requirements_sdxl.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3f86855e1d1e03c20874e9de4a6037a119509444
--- /dev/null
+++ b/diffusers/examples/dreambooth/requirements_sdxl.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+peft==0.7.0
\ No newline at end of file
diff --git a/diffusers/examples/dreambooth/test_dreambooth.py b/diffusers/examples/dreambooth/test_dreambooth.py
new file mode 100644
index 0000000000000000000000000000000000000000..86d0438fc25f4f0d8da4a98d9c1a6a9632dccb6c
--- /dev/null
+++ b/diffusers/examples/dreambooth/test_dreambooth.py
@@ -0,0 +1,227 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBooth(ExamplesTestsAccelerate):
+ def test_dreambooth(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_if(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_checkpointing(self):
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check can run the original fully trained output pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=1)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+            # Run training script for 6 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=1)
+
+ # check old checkpoints do not exist
+ self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+ # check new checkpoints exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/diffusers/examples/dreambooth/test_dreambooth_flux.py b/diffusers/examples/dreambooth/test_dreambooth_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d5703d2a24aad78dc1fe7c3b1a91187ed3f8614
--- /dev/null
+++ b/diffusers/examples/dreambooth/test_dreambooth_flux.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, FluxTransformer2DModel
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothFlux(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
+ script_path = "examples/dreambooth/train_dreambooth_flux.py"
+
+ def test_dreambooth(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check can run the original fully trained output pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+ # check can run an intermediate checkpoint
+ transformer = FluxTransformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
+ pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+            # Run training script for 6 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # check old checkpoints do not exist
+ self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+ # check new checkpoints exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/diffusers/examples/dreambooth/test_dreambooth_lora.py b/diffusers/examples/dreambooth/test_dreambooth_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..9df104e071af9fc02d573678f86a651e80aa16d6
--- /dev/null
+++ b/diffusers/examples/dreambooth/test_dreambooth_lora.py
@@ -0,0 +1,379 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+from diffusers import DiffusionPipeline # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRA(ExamplesTestsAccelerate):
+ def test_dreambooth_lora(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --train_text_encoder
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # check `text_encoder` is present at all.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ keys = lora_state_dict.keys()
+ is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_text_encoder_present)
+
+ # the names of the keys of the state dict should either start with `unet`
+ # or `text_encoder`.
+ is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_correct_naming)
+
+ def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+ def test_dreambooth_lora_if_model(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+
+class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
+ def test_dreambooth_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_custom_captions(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --caption_column text
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --caption_column text
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --train_text_encoder
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
diff --git a/diffusers/examples/dreambooth/test_dreambooth_lora_edm.py b/diffusers/examples/dreambooth/test_dreambooth_lora_edm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f6b3674b8b35a53ed1145c5415886f89f134726
--- /dev/null
+++ b/diffusers/examples/dreambooth/test_dreambooth_lora_edm.py
@@ -0,0 +1,99 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate):
+ def test_dreambooth_lora_sdxl_with_edm(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --do_edm_style_training
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_playground(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
diff --git a/diffusers/examples/dreambooth/test_dreambooth_lora_flux.py b/diffusers/examples/dreambooth/test_dreambooth_lora_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..d197c8187b8745fc44a412474ae338fb16ab6414
--- /dev/null
+++ b/diffusers/examples/dreambooth/test_dreambooth_lora_flux.py
@@ -0,0 +1,198 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
+ script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
+
+ def test_dreambooth_lora_flux(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_text_encoder_flux(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --train_text_encoder
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ starts_with_expected_prefix = all(
+ (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_expected_prefix)
+
+ def test_dreambooth_lora_latent_caching(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
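+ # After resuming from checkpoint-4 and training to step 8 with `checkpoints_total_limit=2`,
+ # only the two most recent checkpoints (6 and 8) should remain; older ones are pruned.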
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/diffusers/examples/dreambooth/test_dreambooth_lora_sd3.py b/diffusers/examples/dreambooth/test_dreambooth_lora_sd3.py
new file mode 100644
index 0000000000000000000000000000000000000000..518738b78246b3f1d97c5a79d6f378a36d3e96d3
--- /dev/null
+++ b/diffusers/examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASD3(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
+ script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
+
+ def test_dreambooth_lora_sd3(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_text_encoder_sd3(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --train_text_encoder
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ starts_with_expected_prefix = all(
+ (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_expected_prefix)
+
+ def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/diffusers/examples/dreambooth/test_dreambooth_sd3.py b/diffusers/examples/dreambooth/test_dreambooth_sd3.py
new file mode 100644
index 0000000000000000000000000000000000000000..19fb7243cefdb8fefd2b6fc0b3d36e0c1522c1bf
--- /dev/null
+++ b/diffusers/examples/dreambooth/test_dreambooth_sd3.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, SD3Transformer2DModel
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothSD3(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
+ script_path = "examples/dreambooth/train_dreambooth_sd3.py"
+
+ def test_dreambooth(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check can run the original fully trained output pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+ # check can run an intermediate checkpoint
+ transformer = SD3Transformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
+ pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 6 total steps, resuming from checkpoint 4
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # check old checkpoints do not exist
+ self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+ # check new checkpoints exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/diffusers/examples/dreambooth/train_dreambooth.py b/diffusers/examples/dreambooth/train_dreambooth.py
new file mode 100644
index 0000000000000000000000000000000000000000..5099107118e4e10767d5e9ecb0babfac7a437e5f
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth.py
@@ -0,0 +1,1443 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import importlib
+import itertools
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, model_info, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ train_text_encoder=False,
+ prompt: str = None,
+ repo_folder: str = None,
+ pipeline: DiffusionPipeline = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# DreamBooth - {repo_id}
+
+This is a DreamBooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
+You can find some example images below. \n
+{img_str}
+
+DreamBooth for the text encoder was enabled: {train_text_encoder}.
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ prompt=prompt,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = ["text-to-image", "dreambooth", "diffusers-training"]
+ if isinstance(pipeline, StableDiffusionPipeline):
+ tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
+ else:
+ tags.extend(["if", "if-diffusers"])
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ prompt_embeds,
+ negative_prompt_embeds,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ pipeline_args = {}
+
+ if vae is not None:
+ pipeline_args["vae"] = vae
+
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ module = importlib.import_module("diffusers")
+ scheduler_class = getattr(module, args.validation_scheduler)
+ pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.pre_compute_text_embeddings:
+ pipeline_args = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ }
+ else:
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ if args.validation_images is None:
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+ else:
+ for image in args.validation_images:
+ image = Image.open(image)
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
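+# Illustrative note: for a standard Stable Diffusion checkpoint the text encoder config reports
+# `architectures == ["CLIPTextModel"]`, so this helper returns `transformers.CLIPTextModel`;
+# DeepFloyd IF-style pipelines report "T5EncoderModel" instead.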
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more details"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+
+ parser.add_argument(
+ "--offset_noise",
+ action="store_true",
+ default=False,
+ help=(
+ "Fine-tuning against a modified noise"
+ " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--pre_compute_text_embeddings",
+ action="store_true",
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+ )
+ parser.add_argument(
+ "--tokenizer_max_length",
+ type=int,
+ default=None,
+ required=False,
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+ )
+ parser.add_argument(
+ "--text_encoder_use_attention_mask",
+ action="store_true",
+ required=False,
+ help="Whether to use attention mask for the text encoder",
+ )
+ parser.add_argument(
+ "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
+ )
+ parser.add_argument(
+ "--validation_images",
+ required=False,
+ default=None,
+ nargs="+",
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+ )
+ parser.add_argument(
+ "--class_labels_conditioning",
+ required=False,
+ default=None,
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+ )
+ parser.add_argument(
+ "--validation_scheduler",
+ type=str,
+ default="DPMSolverMultistepScheduler",
+ choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
+ help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+ return args
+
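+# Example launch (illustrative sketch; the model id, data directory and prompt below are placeholders):
+#   accelerate launch train_dreambooth.py \
+#     --pretrained_model_name_or_path <model-id-or-path> \
+#     --instance_data_dir ./instance_images --instance_prompt "a photo of sks dog" \
+#     --resolution 512 --train_batch_size 1 --gradient_accumulation_steps 1 \
+#     --learning_rate 5e-6 --lr_scheduler constant --lr_warmup_steps 0 \
+#     --max_train_steps 400 --output_dir ./dreambooth-model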
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ encoder_hidden_states=None,
+ class_prompt_encoder_hidden_states=None,
+ tokenizer_max_length=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.encoder_hidden_states = encoder_hidden_states
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+ self.tokenizer_max_length = tokenizer_max_length
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.encoder_hidden_states is not None:
+ example["instance_prompt_ids"] = self.encoder_hidden_states
+ else:
+ text_inputs = tokenize_prompt(
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["instance_prompt_ids"] = text_inputs.input_ids
+ example["instance_attention_mask"] = text_inputs.attention_mask
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+
+ if self.class_prompt_encoder_hidden_states is not None:
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
+ else:
+ class_text_inputs = tokenize_prompt(
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["class_prompt_ids"] = class_text_inputs.input_ids
+ example["class_attention_mask"] = class_text_inputs.attention_mask
+
+ return example
+
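+# Each example dict holds "instance_images" and "instance_prompt_ids" (plus "instance_attention_mask"
+# when text embeddings are not pre-computed), and the corresponding "class_*" keys when a class data root is set.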
+
+def collate_fn(examples, with_prior_preservation=False):
+ has_attention_mask = "instance_attention_mask" in examples[0]
+
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask = [example["instance_attention_mask"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask += [example["class_attention_mask"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+
+ if has_attention_mask:
+ attention_mask = torch.cat(attention_mask, dim=0)
+ batch["attention_mask"] = attention_mask
+
+ return batch
+
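+# Note: with prior preservation the batch stacks the instance examples first and the class examples second,
+# so downstream code can split the model prediction in half to compute the prior-preservation loss.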
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def model_has_vae(args):
+ config_file_name = Path("vae", AutoencoderKL.config_name).as_posix()
+ if os.path.isdir(args.pretrained_model_name_or_path):
+ config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
+ return os.path.isfile(config_file_name)
+ else:
+ files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
+ return any(file.rfilename == config_file_name for file in files_in_repo)
+
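+# This checks for `vae/config.json` either on disk or among the Hub repo files; pixel-space pipelines
+# such as DeepFloyd IF have no VAE, in which case the raw pixel values are used as the model input.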
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+ if tokenizer_max_length is not None:
+ max_length = tokenizer_max_length
+ else:
+ max_length = tokenizer.model_max_length
+
+ text_inputs = tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+
+ return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+ text_input_ids = input_ids.to(text_encoder.device)
+
+ if text_encoder_use_attention_mask:
+ attention_mask = attention_mask.to(text_encoder.device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ return_dict=False,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
+
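+# The two helpers above are typically chained; a minimal sketch (the prompt is a placeholder):
+#   text_inputs = tokenize_prompt(tokenizer, "a photo of sks dog")
+#   embeds = encode_prompt(text_encoder, text_inputs.input_ids, text_inputs.attention_mask,
+#                          text_encoder_use_attention_mask=args.text_encoder_use_attention_mask)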
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Currently, it's not possible to do gradient accumulation when training two models with `accelerate.accumulate`.
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+
+ if model_has_vae(args):
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ else:
+ vae = None
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(text_encoder))):
+ # load transformers-style weights into the model
+ load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
+ model.config = load_model.config
+ else:
+ # load diffusers-style weights into the model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if vae is not None:
+ vae.requires_grad_(False)
+
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training. copy of the weights should still be float32."
+ )
+
+ if unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")
+
+ if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
+ raise ValueError(
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
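+ # Illustrative arithmetic: with the default 5e-6 learning rate, 2 processes, batch size 4 and
+ # 2 accumulation steps, the scaled learning rate becomes 5e-6 * 2 * 4 * 2 = 8e-5.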
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.pre_compute_text_embeddings:
+
+ def compute_text_embeddings(prompt):
+ with torch.no_grad():
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+ prompt_embeds = encode_prompt(
+ text_encoder,
+ text_inputs.input_ids,
+ text_inputs.attention_mask,
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ return prompt_embeds
+
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
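+ # The empty prompt yields the unconditional embedding, passed as `negative_prompt_embeds` during validation.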
+
+ if args.validation_prompt is not None:
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+ else:
+ validation_prompt_encoder_hidden_states = None
+
+ if args.class_prompt is not None:
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
+ else:
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ text_encoder = None
+ tokenizer = None
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ pre_computed_encoder_hidden_states = None
+ validation_prompt_encoder_hidden_states = None
+ validation_prompt_negative_prompt_embeds = None
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the non-trainable weights (the vae and, when it is not being trained, the text encoder) to half-precision,
+ # as these weights are only used for inference and keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and text_encoder to device and cast to weight_dtype
+ if vae is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ if not args.train_text_encoder and text_encoder is not None:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ tracker_config.pop("validation_images")
+ accelerator.init_trackers("dreambooth", config=tracker_config)
+
+ # Train!
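+    # Illustrative example (hypothetical numbers): a per-device batch of 2 on 4 processes with
+    # 8 gradient accumulation steps yields an effective batch size of 2 * 4 * 8 = 64 below.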
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
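+            # e.g. (hypothetical numbers) resuming at global_step=1250 with 500 updates per epoch
+            # restarts training at epoch index 1250 // 500 = 2.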
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+
+ if vae is not None:
+ # Convert images to latent space
+ model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ else:
+ model_input = pixel_values
+
+ # Sample noise that we'll add to the model input
+ if args.offset_noise:
+ noise = torch.randn_like(model_input) + 0.1 * torch.randn(
+ model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
+ )
+ else:
+ noise = torch.randn_like(model_input)
+ bsz, channels, height, width = model_input.shape
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
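+                # For DDPM-style schedulers this closed form is
+                # noisy_model_input = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise.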
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.pre_compute_text_embeddings:
+ encoder_hidden_states = batch["input_ids"]
+ else:
+ encoder_hidden_states = encode_prompt(
+ text_encoder,
+ batch["input_ids"],
+ batch["attention_mask"],
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ if unwrap_model(unet).config.in_channels == channels * 2:
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
+
+ if args.class_labels_conditioning == "timesteps":
+ class_labels = timesteps
+ else:
+ class_labels = None
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
+ )[0]
+
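+                # A 6-channel prediction means the UNet also predicts a variance (typically 3 noise +
+                # 3 variance channels); keep only the predicted noise half for the loss below.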
+ if model_pred.shape[1] == 6:
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Compute instance loss
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
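+                    # Concretely, base_weight = min(SNR_t, snr_gamma) / SNR_t: timesteps with low noise
+                    # (large SNR) are down-weighted, while noisier timesteps keep a weight of 1.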
+ snr = compute_snr(noise_scheduler, timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
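+                            # e.g. with checkpoints_total_limit=3 and 3 checkpoints already on disk,
+                            # num_to_remove = 3 - 3 + 1 = 1, so only the oldest checkpoint is deleted.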
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ images = []
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
+ tokenizer,
+ unwrap_model(unet),
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ validation_prompt_encoder_hidden_states,
+ validation_prompt_negative_prompt_embeds,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ pipeline_args = {}
+
+ if text_encoder is not None:
+ pipeline_args["text_encoder"] = unwrap_model(text_encoder)
+
+ if args.skip_save_text_encoder:
+ pipeline_args["text_encoder"] = None
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ pipeline=pipeline,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/dreambooth/train_dreambooth_flax.py b/diffusers/examples/dreambooth/train_dreambooth_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..29fd5e78535de36097f6adc52dec00d475b78aba
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_flax.py
@@ -0,0 +1,701 @@
+import argparse
+import logging
+import math
+import os
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import torch
+import torch.utils.checkpoint
+import transformers
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from jax.experimental.compilation_cache import compilation_cache as cc
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxDDPMScheduler,
+ FlaxPNDMScheduler,
+ FlaxStableDiffusionPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+# Cache compiled models across invocations of this script.
+cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained vae or vae identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
+ parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
+            " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.instance_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+            raise ValueError("Instance images root doesn't exist.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+    # Set up logging; we only want one process per machine to log things on the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ rng = jax.random.PRNGKey(args.seed)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
+ ):
+ prompt_ids = pipeline.prepare_inputs(example["prompt"])
+ prompt_ids = shard(prompt_ids)
+ p_params = jax_utils.replicate(params)
+ rng = jax.random.split(rng)[0]
+ sample_rng = jax.random.split(rng, jax.device_count())
+ images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
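+                # The jitted pipeline returns images with a leading device axis, i.e.
+                # (num_devices, per_device_batch, H, W, C); flatten it into a single batch.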
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+ images = pipeline.numpy_to_pil(np.array(images))
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+
+ # Handle the repository creation
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+    # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ else:
+ raise NotImplementedError("No tokenizer specified!")
+
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
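+        # The first half of the batch holds instance examples and the second half class examples;
+        # the train step later splits the predictions back into these two halves for the prior loss.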
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = tokenizer.pad(
+ {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+ ).input_ids
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+ batch = {k: v.numpy() for k, v in batch.items()}
+ return batch
+
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
+ if len(train_dataset) < total_train_batch_size:
+ raise ValueError(
+ f"Training batch size is {total_train_batch_size}, but your dataset only contains"
+ f" {len(train_dataset)} images. Please, use a larger dataset or reduce the effective batch size. Note that"
+ f" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that."
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True
+ )
+
+ weight_dtype = jnp.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = jnp.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = jnp.bfloat16
+
+ if args.pretrained_vae_name_or_path:
+ # TODO(patil-suraj): Upload flax weights for the VAE
+ vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
+ else:
+ vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ dtype=weight_dtype,
+ revision=args.revision,
+ )
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+ vae_arg,
+ dtype=weight_dtype,
+ **vae_kwargs,
+ )
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ dtype=weight_dtype,
+ revision=args.revision,
+ )
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ adamw = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ optimizer = optax.chain(
+ optax.clip_by_global_norm(args.max_grad_norm),
+ adamw,
+ )
+
+ unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
+ text_encoder_state = train_state.TrainState.create(
+ apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer
+ )
+
+ noise_scheduler = FlaxDDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ noise_scheduler_state = noise_scheduler.create_state()
+
+ # Initialize our training
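+    # One PRNG key per local device; each pmapped train step consumes its key and returns a fresh one.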
+ train_rngs = jax.random.split(rng, jax.local_device_count())
+
+ def train_step(unet_state, text_encoder_state, vae_params, batch, train_rng):
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
+
+ if args.train_text_encoder:
+ params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params}
+ else:
+ params = {"unet": unet_state.params}
+
+ def compute_loss(params):
+ # Convert images to latent space
+ vae_outputs = vae.apply(
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ # Sample a random timestep for each image
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.train_text_encoder:
+ encoder_hidden_states = text_encoder_state.apply_fn(
+ batch["input_ids"], params=params["text_encoder"], dropout_rng=dropout_rng, train=True
+ )[0]
+ else:
+ encoder_hidden_states = text_encoder(
+ batch["input_ids"], params=text_encoder_state.params, train=False
+ )[0]
+
+ # Predict the noise residual
+ model_pred = unet.apply(
+ {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
+ target, target_prior = jnp.split(target, 2, axis=0)
+
+ # Compute instance loss
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ # Compute prior loss
+ prior_loss = (target_prior - model_pred_prior) ** 2
+ prior_loss = prior_loss.mean()
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(params)
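+        # Average the gradients across all devices mapped over the "batch" axis so that
+        # every replica applies the same parameter update.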
+ grad = jax.lax.pmean(grad, "batch")
+
+ new_unet_state = unet_state.apply_gradients(grads=grad["unet"])
+ if args.train_text_encoder:
+ new_text_encoder_state = text_encoder_state.apply_gradients(grads=grad["text_encoder"])
+ else:
+ new_text_encoder_state = text_encoder_state
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+ return new_unet_state, new_text_encoder_state, metrics, new_train_rng
+
+ # Create parallel version of the train step
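+    # donate_argnums=(0, 1) marks the two train states as donated, letting XLA reuse their
+    # device buffers for the updated states instead of allocating new ones.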
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
+
+ # Replicate the train state on each device
+ unet_state = jax_utils.replicate(unet_state)
+ text_encoder_state = jax_utils.replicate(text_encoder_state)
+ vae_params = jax_utils.replicate(vae_params)
+
+ # Train!
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ def checkpoint(step=None):
+ # Create the pipeline using the trained modules and save it.
+ scheduler, _ = FlaxPNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
+ )
+ pipeline = FlaxStableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+
+ outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
+ pipeline.save_pretrained(
+ outdir,
+ params={
+ "text_encoder": get_params_to_save(text_encoder_state.params),
+ "vae": get_params_to_save(vae_params),
+ "unet": get_params_to_save(unet_state.params),
+ "safety_checker": safety_checker.params,
+ },
+ )
+
+ if args.push_to_hub:
+ message = f"checkpoint-{step}" if step is not None else "End of training"
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message=message,
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ global_step = 0
+
+ epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+ # train
+ for batch in train_dataloader:
+ batch = shard(batch)
+ unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
+ unet_state, text_encoder_state, vae_params, batch, train_rngs
+ )
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(jax.local_device_count())
+
+ global_step += 1
+ if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
+ checkpoint(global_step)
+ if global_step >= args.max_train_steps:
+ break
+
+ train_metric = jax_utils.unreplicate(train_metric)
+
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+ if jax.process_index() == 0:
+ checkpoint()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/dreambooth/train_dreambooth_flux.py b/diffusers/examples/dreambooth/train_dreambooth_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e0f4e09a461b219af169b5b77b84121692b369a
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_flux.py
@@ -0,0 +1,1791 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
+from diffusers.utils import (
+ check_min_version,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Flux [dev] DreamBooth - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).
+
+Was the text encoder fine-tuned? {train_text_encoder}.
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.bfloat16).to('cuda')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "flux",
+ "flux-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def load_text_encoders(class_one, class_two):
+ text_encoder_one = class_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = class_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ return text_encoder_one, text_encoder_two
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+ autocast_ctx = nullcontext()
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=77,
+        help="Maximum sequence length to use with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-dreambooth",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+ help="the FLUX.1 dev variant is a guidance distilled model",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+            # Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exist.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
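+ # p=1.0 here because the 50% chance to flip is decided explicitly below via random.random().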
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
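+ # Pre-compute the normalized tensor for every (repeated) instance image so __getitem__ only has to index into self.pixel_values.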
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # custom prompts were provided, but length does not match size of image dataset
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, max_sequence_length):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+def _encode_prompt_with_t5(
+ text_encoder,
+ tokenizer,
+ max_sequence_length=512,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def _encode_prompt_with_clip(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ text_input_ids=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+
+def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+ text_input_ids_list=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+ dtype = text_encoders[0].dtype
+ device = device if device is not None else text_encoders[1].device
+ pooled_prompt_embeds = _encode_prompt_with_clip(
+ text_encoder=text_encoders[0],
+ tokenizer=tokenizers[0],
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
+ )
+
+ prompt_embeds = _encode_prompt_with_t5(
+ text_encoder=text_encoders[1],
+ tokenizer=tokenizers[1],
+ max_sequence_length=max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
+ )
+
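+ # Text token ids for the Flux transformer's positional embedding; all zeros, shape (batch, seq_len, 3).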
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = FluxPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
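+ # Hash the image bytes so each generated class image gets a unique, deterministic filename.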
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+ tokenizer_two = T5TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ transformer.requires_grad_(True)
+ vae.requires_grad_(False)
+ if args.train_text_encoder:
+ text_encoder_one.requires_grad_(True)
+ text_encoder_two.requires_grad_(False)
+ else:
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if not args.train_text_encoder:
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for i, model in enumerate(models):
+ if isinstance(unwrap_model(model), FluxTransformer2DModel):
+ unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer"))
+ elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
+ if isinstance(unwrap_model(model), CLIPTextModelWithProjection):
+ unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder"))
+ else:
+ unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2"))
+ else:
+ raise ValueError(f"Wrong model supplied: {type(model)=}.")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ if isinstance(unwrap_model(model), FluxTransformer2DModel):
+ load_model = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
+ try:
+ load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder")
+ model(**load_model.config)
+ model.load_state_dict(load_model.state_dict())
+ except Exception:
+ try:
+ load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2")
+ model(**load_model.config)
+ model.load_state_dict(load_model.state_dict())
+ except Exception:
+ raise ValueError(f"Couldn't load the model of type: ({type(model)}).")
+ else:
+ raise ValueError(f"Unsupported model found: {type(model)=}")
+
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
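+ # Linearly scale the base learning rate with the effective batch size (accumulation steps x per-device batch x processes).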
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # different learning rate for text encoder and unet
+ text_parameters_one_with_lr = {
+ "params": text_encoder_one.parameters(),
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ transformer_parameters_with_lr,
+ text_parameters_one_with_lr,
+ ]
+ else:
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+ "Defaulting to adamW."
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # Prodigy only uses --learning_rate as the initial learning rate, so override the text encoder
+ # param group (params_to_optimize[1]) to use it as well.
+ params_to_optimize[1]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders, tokenizers, prompt, args.max_sequence_length
+ )
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ text_ids = text_ids.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if not args.train_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders
+ # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
+ del text_encoder_one, text_encoder_two
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ if not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ pooled_prompt_embeds = instance_pooled_prompt_embeds
+ text_ids = instance_text_ids
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
+ text_ids = torch.cat([text_ids, class_text_ids], dim=0)
+ # if we're optimizing the text encoder (whether the instance prompt is shared across all images or custom
+ # prompts are provided) we need to tokenize and encode the batch prompts on every training step
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, max_sequence_length=512)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, max_sequence_length=512)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ (
+ transformer,
+ text_encoder_one,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ transformer,
+ text_encoder_one,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ )
+ else:
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux-dev-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
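+ # Look up the flow-matching sigma for each sampled timestep and reshape it so it broadcasts over the latent dimensions.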
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ if args.train_text_encoder:
+ models_to_accumulate.extend([text_encoder_one])
+ with accelerator.accumulate(models_to_accumulate):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
+ tokens_two = tokenize_prompt(
+ tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
+ )
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=[None, None],
+ text_input_ids_list=[tokens_one, tokens_two],
+ max_sequence_length=args.max_sequence_length,
+ prompt=prompts,
+ )
+ else:
+ if args.train_text_encoder:
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=[None, None],
+ text_input_ids_list=[tokens_one, tokens_two],
+ max_sequence_length=args.max_sequence_length,
+ prompt=args.instance_prompt,
+ )
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
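+ # Scale factor used below to recover spatial dimensions when unpacking the transformer's patchified output.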
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+
+ latent_image_ids = FluxPipeline._prepare_latent_image_ids(
+ model_input.shape[0],
+ model_input.shape[2],
+ model_input.shape[3],
+ accelerator.device,
+ weight_dtype,
+ )
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
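+ # _pack_latents rearranges each latent into non-overlapping 2x2 patches, producing the token sequence the Flux transformer consumes.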
+ packed_noisy_model_input = FluxPipeline._pack_latents(
+ noisy_model_input,
+ batch_size=model_input.shape[0],
+ num_channels_latents=model_input.shape[1],
+ height=model_input.shape[2],
+ width=model_input.shape[3],
+ )
+
+ # handle guidance
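+ # When the transformer was trained with guidance embeddings (as in Flux dev), a scalar guidance value is embedded and passed in instead of doing classifier-free guidance.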
+ if transformer.config.guidance_embeds:
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+ else:
+ guidance = None
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input,
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
+ # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
+ model_pred = FluxPipeline._unpack_latents(
+ model_pred,
+ height=int(model_input.shape[2] * vae_scale_factor / 2),
+ width=int(model_input.shape[3] * vae_scale_factor / 2),
+ vae_scale_factor=vae_scale_factor,
+ )
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
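+ # With zt = (1 - sigma) * x + sigma * noise, the regression target for the predicted velocity is (noise - x).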
+ target = noise - model_input
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(transformer.parameters(), text_encoder_one.parameters())
+ if args.train_text_encoder
+ else transformer.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ if not args.train_text_encoder:
+ text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ else: # even when training the text encoder we're only training text encoder one
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder_2",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = FluxPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+ if not args.train_text_encoder:
+ del text_encoder_one, text_encoder_two
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ pipeline = FluxPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=transformer,
+ text_encoder=text_encoder_one,
+ )
+ else:
+ pipeline = FluxPipeline.from_pretrained(args.pretrained_model_name_or_path, transformer=transformer)
+
+ # save the pipeline
+ pipeline.save_pretrained(args.output_dir)
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = FluxPipeline.from_pretrained(
+ args.output_dir,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
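+# Example launch (hypothetical paths and hyperparameters; adjust to your setup):
+#   accelerate launch app.py \
+#       --pretrained_model_name_or_path black-forest-labs/FLUX.1-dev \
+#       --instance_data_dir ./instance_images --instance_prompt "a photo of sks dog" \
+#       --output_dir ./flux-dreambooth --resolution 1024 --train_batch_size 1 \
+#       --gradient_accumulation_steps 4 --learning_rate 1e-5 --max_train_steps 500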
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/dreambooth/train_dreambooth_lora.py b/diffusers/examples/dreambooth/train_dreambooth_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d7d697bb21de4629b5376d600a2be0c9d7429a1
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_lora.py
@@ -0,0 +1,1432 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.utils import (
+ check_min_version,
+ convert_state_dict_to_diffusers,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model=str,
+ train_text_encoder=False,
+ prompt=str,
+ repo_folder=None,
+ pipeline: DiffusionPipeline = None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# LoRA DreamBooth - {repo_id}
+
+These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below.\n
+{img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ prompt=prompt,
+ model_description=model_description,
+ inference=True,
+ )
+ tags = ["text-to-image", "diffusers", "lora", "diffusers-training"]
+ if isinstance(pipeline, StableDiffusionPipeline):
+ tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
+ else:
+ tags.extend(["if", "if-diffusers"])
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+ if args.validation_images is None:
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, generator=generator).images[0]
+ images.append(image)
+ else:
+ images = []
+ for image in args.validation_images:
+ image = Image.open(image)
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
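+ # The config's "architectures" entry names the text encoder class, so the matching transformers class can be imported lazily below.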
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` `args.num_validation_images` times."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--pre_compute_text_embeddings",
+ action="store_true",
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+ )
+ parser.add_argument(
+ "--tokenizer_max_length",
+ type=int,
+ default=None,
+ required=False,
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+ )
+ parser.add_argument(
+ "--text_encoder_use_attention_mask",
+ action="store_true",
+ required=False,
+ help="Whether to use attention mask for the text encoder",
+ )
+ parser.add_argument(
+ "--validation_images",
+ required=False,
+ default=None,
+ nargs="+",
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+ )
+ parser.add_argument(
+ "--class_labels_conditioning",
+ required=False,
+ default=None,
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify a prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("--class_data_dir is not needed when --with_prior_preservation is not set; it will be ignored.")
+ if args.class_prompt is not None:
+ warnings.warn("--class_prompt is not needed when --with_prior_preservation is not set; it will be ignored.")
+
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ encoder_hidden_states=None,
+ class_prompt_encoder_hidden_states=None,
+ tokenizer_max_length=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.encoder_hidden_states = encoder_hidden_states
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+ self.tokenizer_max_length = tokenizer_max_length
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+            raise ValueError("Instance images root doesn't exist.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.encoder_hidden_states is not None:
+ example["instance_prompt_ids"] = self.encoder_hidden_states
+ else:
+ text_inputs = tokenize_prompt(
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["instance_prompt_ids"] = text_inputs.input_ids
+ example["instance_attention_mask"] = text_inputs.attention_mask
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+
+ if self.class_prompt_encoder_hidden_states is not None:
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
+ else:
+ class_text_inputs = tokenize_prompt(
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["class_prompt_ids"] = class_text_inputs.input_ids
+ example["class_attention_mask"] = class_text_inputs.attention_mask
+
+ return example
+
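+# Editor's note (sketch of the per-item layout produced above): each `example`
+# dict holds "instance_images" (a 3 x size x size tensor), "instance_prompt_ids"
+# and, unless embeddings were pre-computed, "instance_attention_mask"; when a
+# class_data_root is set, the corresponding "class_*" keys are added as well.
+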
+
+def collate_fn(examples, with_prior_preservation=False):
+ has_attention_mask = "instance_attention_mask" in examples[0]
+
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask = [example["instance_attention_mask"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ if has_attention_mask:
+ attention_mask += [example["class_attention_mask"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+
+ if has_attention_mask:
+ batch["attention_mask"] = attention_mask
+
+ return batch
+
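+# Shape sketch (editor's note), assuming a per-device batch size of B and square
+# images of side `size`: without prior preservation the collated batch holds
+# "pixel_values" of shape (B, 3, size, size) and "input_ids" of shape
+# (B, max_length); with --with_prior_preservation the class examples are appended
+# along the batch dimension, so both grow to 2 * B rows.
+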
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+ if tokenizer_max_length is not None:
+ max_length = tokenizer_max_length
+ else:
+ max_length = tokenizer.model_max_length
+
+ text_inputs = tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+
+ return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+ text_input_ids = input_ids.to(text_encoder.device)
+
+ if text_encoder_use_attention_mask:
+ attention_mask = attention_mask.to(text_encoder.device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ return_dict=False,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
+
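+# Illustrative usage sketch (editor's addition, not part of the upstream trainer):
+# the two helpers above are typically chained to turn a raw prompt into
+# conditioning embeddings, e.g.
+#
+#     text_inputs = tokenize_prompt(tokenizer, "a photo of sks dog")
+#     prompt_embeds = encode_prompt(
+#         text_encoder,
+#         text_inputs.input_ids,
+#         text_inputs.attention_mask,
+#         text_encoder_use_attention_mask=False,
+#     )
+#
+# where `tokenizer` and `text_encoder` are assumed to be the objects loaded in
+# `main` below and "a photo of sks dog" is a placeholder instance prompt.
+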
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ try:
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ except OSError:
+        # DeepFloyd IF does not have a VAE, so just set it to None
+        # instead of erroring out here
+ vae = None
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ if vae is not None:
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ if vae is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
+ )
+ unet.add_adapter(unet_lora_config)
+
+ # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+            # there are only two options here: either just the unet attention processor (LoRA) layers,
+            # or both the unet and the text encoder attention layers
+ unet_lora_layers_to_save = None
+ text_encoder_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(model, type(unwrap_model(text_encoder))):
+ text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder))):
+ text_encoder_ = model
+ else:
+                raise ValueError(f"unexpected model to load: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.append(text_encoder_)
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
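+        # Worked example (editor's note): with --learning_rate 1e-4, --gradient_accumulation_steps 2,
+        # --train_batch_size 4 and 2 processes, the scaled learning rate is 1e-4 * 2 * 4 * 2 = 1.6e-3.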
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.append(text_encoder)
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ if args.train_text_encoder:
+ params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.pre_compute_text_embeddings:
+
+ def compute_text_embeddings(prompt):
+ with torch.no_grad():
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+ prompt_embeds = encode_prompt(
+ text_encoder,
+ text_inputs.input_ids,
+ text_inputs.attention_mask,
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ return prompt_embeds
+
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+ if args.validation_prompt is not None:
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+ else:
+ validation_prompt_encoder_hidden_states = None
+
+ if args.class_prompt is not None:
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
+ else:
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ text_encoder = None
+ tokenizer = None
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ pre_computed_encoder_hidden_states = None
+ validation_prompt_encoder_hidden_states = None
+ validation_prompt_negative_prompt_embeds = None
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
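+    # Worked example (editor's note): with 25 batches per epoch and --gradient_accumulation_steps 2,
+    # there are ceil(25 / 2) = 13 update steps per epoch, so --num_train_epochs 4 gives max_train_steps = 52.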
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ tracker_config.pop("validation_images")
+ accelerator.init_trackers("dreambooth-lora", config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+
+ if vae is not None:
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ else:
+ model_input = pixel_values
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz, channels, height, width = model_input.shape
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.pre_compute_text_embeddings:
+ encoder_hidden_states = batch["input_ids"]
+ else:
+ encoder_hidden_states = encode_prompt(
+ text_encoder,
+ batch["input_ids"],
+ batch["attention_mask"],
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ if unwrap_model(unet).config.in_channels == channels * 2:
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
+
+ if args.class_labels_conditioning == "timesteps":
+ class_labels = timesteps
+ else:
+ class_labels = None
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ encoder_hidden_states,
+ class_labels=class_labels,
+ return_dict=False,
+ )[0]
+
+                # If the model predicts variance, throw away that prediction; we only train on the
+                # simplified training objective. This means that all schedulers using the fine-tuned
+                # model must be configured to use one of the fixed variance types.
+ if model_pred.shape[1] == 6:
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
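+                                # e.g. (editor's note): with --checkpoints_total_limit 3 and four existing
+                                # checkpoints on disk, num_to_remove = 4 - 3 + 1 = 2, so the two oldest go.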
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ if args.pre_compute_text_embeddings:
+ pipeline_args = {
+ "prompt_embeds": validation_prompt_encoder_hidden_states,
+ "negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
+ }
+ else:
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype=weight_dtype,
+ )
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet = unet.to(torch.float32)
+
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder = unwrap_model(text_encoder)
+ text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
+ else:
+ text_encoder_state_dict = None
+
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
+ )
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ pipeline=pipeline,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py b/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..6091622719eedf6b477dc8f73bdf0fccde9422d7
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -0,0 +1,1891 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ _set_state_dict_into_text_encoder,
+ cast_training_params,
+ clear_objs_and_retain_memory,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Flux DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).
+
+Was LoRA for the text encoder enabled? {train_text_encoder}.
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "flux",
+ "flux-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def load_text_encoders(class_one, class_two):
+ text_encoder_one = class_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = class_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ return text_encoder_one, text_encoder_two
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+ autocast_ctx = nullcontext()
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
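+# Editor's note (illustrative, based on the FLUX.1-dev repository layout): the
+# "text_encoder" subfolder resolves to CLIPTextModel and "text_encoder_2" to
+# T5EncoderModel, e.g.
+#
+#     cls_one = import_model_class_from_model_name_or_path(name, None)
+#     cls_two = import_model_class_from_model_name_or_path(name, None, subfolder="text_encoder_2")
+#
+# where `name` is assumed to be "black-forest-labs/FLUX.1-dev".
+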
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+        "default, the standard Image Dataset maps 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=512,
+        help="Maximum sequence length to use with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+        help="Whether to train the text encoder. If set, the text encoder should be in float32 precision.",
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+        help="Guidance scale used during training; the FLUX.1 dev variant is a guidance-distilled model.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+        help="Coefficient for computing the Prodigy step size using running averages. If set to None, "
+        "uses the value of the square root of beta2. Ignored if optimizer is AdamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+        help="Turn on Adam's bias correction. True by default. Ignored if optimizer is AdamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+        "Ignored if optimizer is AdamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "If not set, the trained layers are saved in the precision dtype used for training, to save memory."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+                raise ValueError("Instance images root doesn't exist.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # no custom prompts available, fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, max_sequence_length):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+def _encode_prompt_with_t5(
+ text_encoder,
+ tokenizer,
+ max_sequence_length=512,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def _encode_prompt_with_clip(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ text_input_ids=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+
+def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+ text_input_ids_list=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+ dtype = text_encoders[0].dtype
+
+ pooled_prompt_embeds = _encode_prompt_with_clip(
+ text_encoder=text_encoders[0],
+ tokenizer=tokenizers[0],
+ prompt=prompt,
+ device=device if device is not None else text_encoders[0].device,
+ num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
+ )
+
+ prompt_embeds = _encode_prompt_with_t5(
+ text_encoder=text_encoders[1],
+ tokenizer=tokenizers[1],
+ max_sequence_length=max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device if device is not None else text_encoders[1].device,
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
+ )
+
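+ # The Flux transformer expects positional ids for the text tokens; for plain text prompts these are all zeros.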
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = FluxPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+ tokenizer_two = T5TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=weight_dtype)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ transformer.add_adapter(transformer_lora_config)
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ FluxPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+ text_encoder_one_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+ if args.train_text_encoder:
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # different learning rate for the text encoder and the transformer
+ text_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ transformer_parameters_with_lr,
+ text_parameters_one_with_lr,
+ ]
+ else:
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
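+ # Prodigy estimates the step size adaptively; `learning_rate` only acts as a multiplier
+ # on that estimate, which is why values around 1.0 are recommended below.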
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # change the learning rate of the text encoder LoRA parameters to --learning_rate
+ # (params_to_optimize only holds two groups here: the transformer and the single text encoder)
+ params_to_optimize[1]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders, tokenizers, prompt, args.max_sequence_length
+ )
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ text_ids = text_ids.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if not args.train_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ if not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ pooled_prompt_embeds = instance_pooled_prompt_embeds
+ text_ids = instance_text_ids
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
+ text_ids = torch.cat([text_ids, class_text_ids], dim=0)
+ # if we're optimizing the text encoder (whether the instance prompt is used for all images or custom prompts are provided),
+ # we need to tokenize and encode the batch prompts on all training steps
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77)
+ tokens_two = tokenize_prompt(
+ tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length
+ )
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77)
+ class_tokens_two = tokenize_prompt(
+ tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length
+ )
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ vae_config_shift_factor = vae.config.shift_factor
+ vae_config_scaling_factor = vae.config.scaling_factor
+ vae_config_block_out_channels = vae.config.block_out_channels
+ if args.cache_latents:
+ latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=weight_dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
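+ # Only the latent distributions are cached; .sample() is still drawn at every
+ # training step, so the encoded latents vary slightly across epochs.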
+
+ if args.validation_prompt is None:
+ clear_objs_and_retain_memory([vae])
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ (
+ transformer,
+ text_encoder_one,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ transformer,
+ text_encoder_one,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ )
+ else:
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux-dev-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
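+ # Map each sampled timestep to its flow-matching sigma and reshape it so it
+ # broadcasts against the latent tensor (e.g. [B] -> [B, 1, 1, 1] for n_dim=4).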
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ # set top parameter requires_grad = True so that gradient checkpointing works
+ accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ if args.train_text_encoder:
+ models_to_accumulate.extend([text_encoder_one])
+ with accelerator.accumulate(models_to_accumulate):
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
+ tokens_two = tokenize_prompt(
+ tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
+ )
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=[None, None],
+ text_input_ids_list=[tokens_one, tokens_two],
+ max_sequence_length=args.max_sequence_length,
+ device=accelerator.device,
+ prompt=prompts,
+ )
+ else:
+ if args.train_text_encoder:
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=[None, None],
+ text_input_ids_list=[tokens_one, tokens_two],
+ max_sequence_length=args.max_sequence_length,
+ device=accelerator.device,
+ prompt=args.instance_prompt,
+ )
+
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step].sample()
+ else:
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+
+ latent_image_ids = FluxPipeline._prepare_latent_image_ids(
+ model_input.shape[0],
+ model_input.shape[2],
+ model_input.shape[3],
+ accelerator.device,
+ weight_dtype,
+ )
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
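+ # (sigmas == 1 corresponds to pure noise, sigmas == 0 to the clean latents)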
+
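+ # Pack 2x2 latent patches into the token sequence the Flux transformer expects:
+ # [B, C, H, W] -> [B, (H/2) * (W/2), 4 * C].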
+ packed_noisy_model_input = FluxPipeline._pack_latents(
+ noisy_model_input,
+ batch_size=model_input.shape[0],
+ num_channels_latents=model_input.shape[1],
+ height=model_input.shape[2],
+ width=model_input.shape[3],
+ )
+
+ # handle guidance
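+ # (guidance-distilled checkpoints such as FLUX.1-dev expose guidance_embeds and take the
+ # guidance scale as a model input; here a constant scale is used for every sample)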
+ if transformer.config.guidance_embeds:
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+ else:
+ guidance = None
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input,
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model
+ # (we should not keep it, but I want to keep the inputs the same for the model for testing)
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
+ model_pred = FluxPipeline._unpack_latents(
+ model_pred,
+ height=int(model_input.shape[2] * vae_scale_factor / 2),
+ width=int(model_input.shape[3] * vae_scale_factor / 2),
+ vae_scale_factor=vae_scale_factor,
+ )
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = noise - model_input
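+ # i.e. the velocity d z_t / d sigma of the interpolation z_t = (1 - sigma) * x + sigma * noise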
+
+ if args.with_prior_preservation:
+ # Chunk model_pred and target into instance and class halves and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(transformer.parameters(), text_encoder_one.parameters())
+ if args.train_text_encoder
+ else transformer.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ if not args.train_text_encoder:
+ text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ pipeline = FluxPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+ if not args.train_text_encoder:
+ clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+ else:
+ text_encoder_lora_layers = None
+
+ FluxPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = FluxPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/dreambooth/train_dreambooth_lora_sd3.py b/diffusers/examples/dreambooth/train_dreambooth_lora_sd3.py
new file mode 100644
index 0000000000000000000000000000000000000000..3060813bbbdcbed164c1de1e9fd433463e30d2fe
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -0,0 +1,1870 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ SD3Transformer2DModel,
+ StableDiffusion3Pipeline,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ _set_state_dict_into_text_encoder,
+ cast_training_params,
+ clear_objs_and_retain_memory,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# SD3 DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).
+
+Was LoRA for the text encoder enabled? {train_text_encoder}.
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
+
+- **LoRA**: download **[`diffusers_lora_weights.safetensors` here 💾](/{repo_id}/blob/main/diffusers_lora_weights.safetensors)**.
+ - Rename it and place it on your `models/Lora` folder.
+ - On AUTOMATIC1111, load the LoRA by adding `` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "sd3",
+ "sd3-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def load_text_encoders(class_one, class_two, class_three):
+ text_encoder_one = class_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = class_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ text_encoder_three = class_three.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
+ )
+ return text_encoder_one, text_encoder_two, text_encoder_three
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+ autocast_ctx = nullcontext()
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ clear_objs_and_retain_memory(objs=[pipeline])
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+ if model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=77,
+ help="Maximum sequence length to use with with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd3-dreambooth",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder (clip text encoders only). If set, the text encoder should be float32 precision.",
+ )
+
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="logit_normal",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--precondition_outputs",
+ type=int,
+ default=1,
+ help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
+ "model `target` is calculated.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
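+# Illustrative only -- a minimal sketch of how the flags above compose when the
+# parser is driven programmatically. The model id and local paths here are
+# assumptions for the example, not values shipped with this project, and the
+# call is left commented out so nothing runs at import time:
+#
+#   example_args = parse_args([
+#       "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
+#       "--instance_data_dir", "./dog",
+#       "--instance_prompt", "a photo of sks dog",
+#       "--resolution", "1024",
+#       "--learning_rate", "1e-4",
+#       "--lr_scheduler", "constant",
+#       "--max_train_steps", "500",
+#   ])
+#   main(example_args)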
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # If --dataset_name is provided, or a metadata jsonl file is present in the local --instance_data_dir,
+ # we load the training data using load_dataset.
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
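+ # Note: instance images are fully preprocessed once in __init__ (resize, optional
+ # flip/crop, ToTensor, Normalize) and cached in self.pixel_values, while class
+ # images are opened and transformed lazily in __getitem__ via image_transforms.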
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # no custom prompts were provided, so fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
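+# For reference, with with_prior_preservation=True the batch produced above is
+# ordered instance-first, then class: pixel_values has shape
+# [2 * batch_size, 3, size, size] and prompts lists the instance prompts followed
+# by the class prompts. The training loop later relies on this ordering when it
+# splits predictions with torch.chunk(..., 2, dim=0).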
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
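+# Note: the hard-coded max_length=77 above matches the CLIP tokenizers' context
+# window; the T5 branch below tokenizes separately with `max_sequence_length`.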
+
+def _encode_prompt_with_t5(
+ text_encoder,
+ tokenizer,
+ max_sequence_length,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def _encode_prompt_with_clip(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ text_input_ids=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+ text_input_ids_list=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ clip_tokenizers = tokenizers[:2]
+ clip_text_encoders = text_encoders[:2]
+
+ clip_prompt_embeds_list = []
+ clip_pooled_prompt_embeds_list = []
+ for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ prompt=prompt,
+ device=device if device is not None else text_encoder.device,
+ num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
+ )
+ clip_prompt_embeds_list.append(prompt_embeds)
+ clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
+
+ clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
+
+ t5_prompt_embed = _encode_prompt_with_t5(
+ text_encoders[-1],
+ tokenizers[-1],
+ max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
+ device=device if device is not None else text_encoders[-1].device,
+ )
+
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+
+ return prompt_embeds, pooled_prompt_embeds
+
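+# Shape sketch for the block above (exact dimensions depend on the loaded
+# checkpoints and are assumptions for the standard SD3 text encoders): the two
+# CLIP encoders produce per-token embeddings that are concatenated on the feature
+# axis, zero-padded up to the T5 hidden size, and then concatenated with the T5
+# tokens on the sequence axis, so prompt_embeds is roughly
+# [batch * num_images_per_prompt, 77 + max_sequence_length, t5_dim], while
+# pooled_prompt_embeds is the concatenation of the two CLIP pooled vectors.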
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ clear_objs_and_retain_memory(objs=[pipeline])
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+ tokenizer_two = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ )
+ tokenizer_three = T5TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_3",
+ revision=args.revision,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+ text_encoder_cls_three = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
+ text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = SD3Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ text_encoder_three.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=torch.float32)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_three.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+ text_encoder_two.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusion3Pipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
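+ # save_model_hook above writes one combined LoRA file; load_model_hook below reads
+ # it back via StableDiffusion3Pipeline.lora_state_dict, whose keys are prefixed per
+ # component ("transformer.", "text_encoder.", "text_encoder_2."), which is why the
+ # prefixes are stripped or filtered before the weights are handed to PEFT.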
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ _set_state_dict_into_text_encoder(
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_, text_encoder_two_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one, text_encoder_two])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+ if args.train_text_encoder:
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # different learning rate for text encoder and unet
+ text_lora_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ text_lora_parameters_two_with_lr = {
+ "params": text_lora_parameters_two,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ transformer_parameters_with_lr,
+ text_lora_parameters_one_with_lr,
+ text_lora_parameters_two_with_lr,
+ ]
+ else:
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
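+ # A note on Prodigy (an assumption about the prodigyopt package rather than
+ # something this script configures): its `lr` acts as a multiplier on the
+ # internally estimated step size, which is why the warning above suggests
+ # leaving it near 1.0 instead of a typical Adam-style value like 1e-4.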
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
+ text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders, tokenizers, prompt, args.max_sequence_length
+ )
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds
+
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if not args.train_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
+ clear_objs_and_retain_memory(
+ objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
+ )
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ if not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ pooled_prompt_embeds = instance_pooled_prompt_embeds
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
+ # If we're optimizing the text encoder (whether the instance prompt is shared across all images or custom
+ # prompts are provided), we need to tokenize and encode the batch prompts on every training step.
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
+ tokens_three = tokenize_prompt(tokenizer_three, args.instance_prompt)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
+ class_tokens_three = tokenize_prompt(tokenizer_three, args.class_prompt)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+ tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
+
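+ # In both branches above, the class-prompt representation is concatenated after the
+ # instance-prompt representation along the batch dimension, matching collate_fn's
+ # instance-first ordering, which the later torch.chunk(model_pred, 2, dim=0) relies on.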
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ (
+ transformer,
+ text_encoder_one,
+ text_encoder_two,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ assert text_encoder_one is not None
+ assert text_encoder_two is not None
+ assert text_encoder_three is not None
+ else:
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-sd3-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
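+ # get_sigmas looks up the sigma associated with each sampled timestep and
+ # unsqueezes it to n_dim so it broadcasts against the [B, C, H, W] latents in
+ # the flow-matching interpolation below: zt = (1 - sigma) * x0 + sigma * noise.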
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+
+ # set top-level parameters' requires_grad = True so that gradient checkpointing works
+ accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ with accelerator.accumulate(models_to_accumulate):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image
+ if train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts)
+ tokens_two = tokenize_prompt(tokenizer_two, prompts)
+ tokens_three = tokenize_prompt(tokenizer_three, prompts)
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
+ tokenizers=[None, None, None],
+ prompt=prompts,
+ max_sequence_length=args.max_sequence_length,
+ text_input_ids_list=[tokens_one, tokens_two, tokens_three],
+ )
+ else:
+ if args.train_text_encoder:
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
+ tokenizers=[None, None, tokenizer_three],
+ prompt=args.instance_prompt,
+ max_sequence_length=args.max_sequence_length,
+ text_input_ids_list=[tokens_one, tokens_two, tokens_three],
+ )
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ timestep=timesteps,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ return_dict=False,
+ )[0]
+
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ # Preconditioning of the model outputs.
+ if args.precondition_outputs:
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ if args.precondition_outputs:
+ target = model_input
+ else:
+ target = noise - model_input
+
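+ # With precondition_outputs, the prediction was already converted back into an
+ # estimate of the clean latents (model_pred * (-sigmas) + noisy_model_input), so
+ # the regression target is model_input itself; otherwise the raw prediction is
+ # regressed against the flow-matching velocity (noise - model_input).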
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(
+ transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
+ )
+ if args.train_text_encoder
+ else transformer_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ if not args.train_text_encoder:
+ # create pipeline
+ text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
+ text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
+ )
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ text_encoder_3=accelerator.unwrap_model(text_encoder_three),
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+ objs = []
+ if not args.train_text_encoder:
+ objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
+
+ clear_objs_and_retain_memory(objs=objs)
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ transformer = transformer.to(torch.float32)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+ text_encoder_two = unwrap_model(text_encoder_two)
+ text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusion3Pipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ train_text_encoder=args.train_text_encoder,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py b/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..016464165c440506d25a754e151d7761808740df
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -0,0 +1,1987 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import load_file, save_file
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ EDMEulerScheduler,
+ EulerDiscreteScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+ check_min_version,
+ convert_all_state_dict_to_peft,
+ convert_state_dict_to_diffusers,
+ convert_state_dict_to_kohya,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def determine_scheduler_type(pretrained_model_name_or_path, revision):
+ model_index_filename = "model_index.json"
+ if os.path.isdir(pretrained_model_name_or_path):
+ model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
+ else:
+ model_index = hf_hub_download(
+ repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
+ )
+
+ with open(model_index, "r") as f:
+ scheduler_type = json.load(f)["scheduler"][1]
+ return scheduler_type
+
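+# determine_scheduler_type reads the pipeline's model_index.json, where each
+# component is recorded as a ["library", "ClassName"] pair, and returns the
+# scheduler's class name (e.g. "EulerDiscreteScheduler" vs. "EDMEulerScheduler"),
+# which the rest of the script can use to detect EDM-style checkpoints such as
+# Playground v2.5.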
+
+def save_model_card(
+ repo_id: str,
+ use_dora: bool,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+ vae_path=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} LoRA adaptation weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
+
+## Trigger words
+
+You should use {instance_prompt} to trigger the image generation.
+
+## Download model
+
+Weights for this model are available in Safetensors format.
+
+[Download]({repo_id}/tree/main) them in the Files & versions tab.
+
+"""
+ if "playground" in base_model:
+ model_description += """\n
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora" if not use_dora else "dora",
+ "template:sd-lora",
+ ]
+ if "playground" in base_model:
+ tags.extend(["playground", "playground-diffusers"])
+ else:
+ tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if not args.do_edm_style_training:
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
+ # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` `args.num_validation_images` times."
+ ),
+ )
+ parser.add_argument(
+ "--do_edm_style_training",
+ default=False,
+ action="store_true",
+ help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--output_kohya_format",
+ action="store_true",
+ help="Flag to additionally generate the final state dict in the Kohya format so that it becomes compatible with A1111, Comfy, Kohya, etc.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="Whether to randomly flip images horizontally.",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="Coefficient for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the square root of beta2. Ignored if optimizer is AdamW.",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is AdamW.",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is AdamW.",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--use_dora",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation (https://arxiv.org/abs/2402.09353). "
+ "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("--class_data_dir is ignored without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("--class_prompt is ignored without --with_prior_preservation.")
+
+ return args
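+
+# Illustrative invocation of this training script (the script name, paths, prompt, and
+# hyperparameter values below are placeholders, not prescribed settings):
+#   accelerate launch <this_script>.py \
+#     --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
+#     --instance_data_dir=./instance_images \
+#     --instance_prompt="a photo of TOK dog" \
+#     --resolution=1024 --train_batch_size=1 --gradient_accumulation_steps=4 \
+#     --learning_rate=1e-4 --max_train_steps=500 --output_dir=lora-dreambooth-model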
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is present in the local --instance_data_dir directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exist.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ # image processing to prepare for using SD-XL micro-conditioning
+ self.original_sizes = []
+ self.crop_top_lefts = []
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ self.original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ self.crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ original_size = self.original_sizes[index % self.num_instance_images]
+ crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+ example["original_size"] = original_size
+ example["crop_top_left"] = crop_top_left
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else:  # no custom prompts were provided; fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+ original_sizes = [example["original_size"] for example in examples]
+ crop_top_lefts = [example["crop_top_left"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+ original_sizes += [example["original_size"] for example in examples]
+ crop_top_lefts += [example["crop_top_left"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {
+ "pixel_values": pixel_values,
+ "prompts": prompts,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+ return batch
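+
+# Shape sketch for a collated batch, assuming train_batch_size=2 and resolution=1024
+# (the batch dimension doubles when prior preservation is enabled):
+#   batch["pixel_values"]   -> float32 tensor of shape (2, 3, 1024, 1024), values in [-1, 1]
+#   batch["prompts"]        -> list of 2 strings
+#   batch["original_sizes"] -> list of 2 (height, width) tuples
+#   batch["crop_top_lefts"] -> list of 2 (top, left) tuples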
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
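+
+# For SDXL's CLIP tokenizers, tokenizer.model_max_length is typically 77, so the returned
+# ids have shape (num_prompts, 77); max-length padding and truncation keep the length fixed.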
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
+ )
+
+ # We are only interested in the pooled output of the final (i.e. second) text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
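+
+# Rough shape sketch, assuming a standard SDXL checkpoint (hidden sizes 768 and 1280):
+#   prompt_embeds        -> (batch, 77, 768 + 1280) = (batch, 77, 2048), the concatenated
+#                           penultimate hidden states of the two text encoders
+#   pooled_prompt_embeds -> (batch, 1280), the pooled output of the second text encoder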
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.do_edm_style_training and args.snr_gamma is not None:
+ raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
+ if "EDM" in scheduler_type:
+ args.do_edm_style_training = True
+ noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ logger.info("Performing EDM-style training!")
+ elif args.do_edm_style_training:
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ logger.info("Performing EDM-style training!")
+ else:
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ latents_mean = latents_std = None
+ if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+ latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
+ "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ use_dora=args.use_dora,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ unet.add_adapter(unet_lora_config)
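+ # With LoRA, each targeted projection computes W x + (lora_alpha / r) * B(A x), where A is
+ # (r x in_features) and B is (out_features x r); since lora_alpha == rank here, the scaling
+ # is 1 and only the injected adapter parameters are trained.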
+
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ use_dora=args.use_dora,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+ text_encoder_two.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # there are only two options here: either just the unet attention (LoRA) layers,
+ # or the unet plus the text encoder attention layers
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+ raise ValueError(f"unexpected model when loading: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ _set_state_dict_into_text_encoder(
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_, text_encoder_two_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
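+ # e.g. with learning_rate=1e-4, gradient_accumulation_steps=1, train_batch_size=4 and
+ # 2 processes, the effective learning rate becomes 1e-4 * 1 * 4 * 2 = 8e-4.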
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one, text_encoder_two])
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+ if args.train_text_encoder:
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
+ # Optimization parameters
+ unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # different learning rate for text encoder and unet
+ text_lora_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ text_lora_parameters_two_with_lr = {
+ "params": text_lora_parameters_two,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ unet_lora_parameters_with_lr,
+ text_lora_parameters_one_with_lr,
+ text_lora_parameters_two_with_lr,
+ ]
+ else:
+ params_to_optimize = [unet_lora_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+ "Defaulting to adamW."
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided for both the unet and the text encoder, e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+ # --learning_rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+ params_to_optimize[2]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Computes additional embeddings/ids required by the SDXL UNet.
+ # regular text embeddings (when `train_text_encoder` is not True)
+ # pooled text embeddings
+ # time ids
+
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
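+
+ # Worked example for compute_time_ids above (illustrative values): original_size=(768, 512),
+ # crops_coords_top_left=(0, 128) and resolution=1024 yield the (1, 6) tensor
+ # [[768, 512, 0, 128, 1024, 1024]] used for SDXL micro-conditioning.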
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if not args.train_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ if not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ unet_add_text_embeds = instance_pooled_prompt_embeds
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
+ # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
+ # batch prompts on all training steps
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
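+ # e.g. 200 batches per epoch with gradient_accumulation_steps=2 give 100 update steps per
+ # epoch, so max_train_steps=500 corresponds to ceil(500 / 100) = 5 training epochs.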
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = (
+ "dreambooth-lora-sd-xl"
+ if "playground" not in args.pretrained_model_name_or_path
+ else "dreambooth-lora-playground"
+ )
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
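+
+ # get_sigmas returns one sigma per sampled timestep, broadcast to n_dim dimensions
+ # (e.g. shape (bsz, 1, 1, 1) for 4D latents) so it can scale the noisy latents directly.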
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+
+ # set requires_grad = True on the top-level embeddings so that gradient checkpointing works
+ accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts)
+ tokens_two = tokenize_prompt(tokenizer_two, prompts)
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+
+ if latents_mean is None and latents_std is None:
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+ else:
+ latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
+ latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
+ model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ if not args.do_edm_style_training:
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+ else:
+ # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
+ # instead of discrete timesteps, so here we sample indices to get the noise levels
+ # from `scheduler.timesteps`
+ indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
+ timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+ # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
+ # We then precondition the final model inputs based on these sigmas instead of the timesteps.
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ if args.do_edm_style_training:
+ sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
+ if "EDM" in scheduler_type:
+ inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
+ else:
+ inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
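+ # This fallback mirrors EulerDiscreteScheduler.scale_model_input: the noisy latents are
+ # scaled by 1 / sqrt(sigma^2 + 1), analogous to the c_in input preconditioning in EDM.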
+
+ # time ids
+ add_time_ids = torch.cat(
+ [
+ compute_time_ids(original_size=s, crops_coords_top_left=c)
+ for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+ ]
+ )
+
+ # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
+ if not train_dataset.custom_instance_prompts:
+ elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
+ else:
+ elems_to_repeat_text_embeds = 1
+
+ # Predict the noise residual
+ if not args.train_text_encoder:
+ unet_added_conditions = {
+ "time_ids": add_time_ids,
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
+ }
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+ else:
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[tokens_one, tokens_two],
+ )
+ unet_added_conditions.update(
+ {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
+ )
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ weighting = None
+ if args.do_edm_style_training:
+ # Similar to the input preconditioning, the model predictions are also preconditioned
+ # on noised model inputs (before preconditioning) and the sigmas.
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ if "EDM" in scheduler_type:
+ model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
+ else:
+ if noise_scheduler.config.prediction_type == "epsilon":
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
+ noisy_model_input / (sigmas**2 + 1)
+ )
+ # We are not doing weighting here because it tends to result in numerical problems.
+ # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+ # There might be other alternatives for weighting as well:
+ # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
+ if "EDM" not in scheduler_type:
+ weighting = (sigmas**-2.0).float()
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = model_input if args.do_edm_style_training else noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = (
+ model_input
+ if args.do_edm_style_training
+ else noise_scheduler.get_velocity(model_input, noise, timesteps)
+ )
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ if weighting is not None:
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+ else:
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ if args.snr_gamma is None:
+ if weighting is not None:
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
+ target.shape[0], -1
+ ),
+ 1,
+ )
+ loss = loss.mean()
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
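+ # Min-SNR example: with snr_gamma=5.0, a timestep with SNR=20 gets weight min(20, 5) / 20 = 0.25,
+ # while one with SNR=2 gets min(2, 5) / 2 = 1.0 (with +1 added for v-prediction).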
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+ if args.train_text_encoder
+ else unet_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ if not args.train_text_encoder:
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder_2",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype=weight_dtype,
+ )
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+ )
+ text_encoder_two = unwrap_model(text_encoder_two)
+ text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+ )
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
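+ # Optionally also write a copy of the LoRA weights converted to the Kohya sd-scripts key layout
+ # (saved alongside the diffusers-format weights), which some third-party tools expect.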
+ if args.output_kohya_format:
+ lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
+ peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+ save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors")
+
+ # Final inference
+ # Load previous pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ use_dora=args.use_dora,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/dreambooth/train_dreambooth_sd3.py b/diffusers/examples/dreambooth/train_dreambooth_sd3.py
new file mode 100644
index 0000000000000000000000000000000000000000..c34024f478c163b9f130a39aa3648a5c45cae918
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_sd3.py
@@ -0,0 +1,1805 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ SD3Transformer2DModel,
+ StableDiffusion3Pipeline,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
+from diffusers.utils import (
+ check_min_version,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# SD3 DreamBooth - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).
+
+Was the text encoder fine-tuned? {train_text_encoder}.
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.float16).to('cuda')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "sd3",
+ "sd3-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
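+# Stable Diffusion 3 conditions on three text encoders: two CLIP text models (subfolders "text_encoder" and
+# "text_encoder_2") and a T5 encoder ("text_encoder_3"). This helper loads all three from the base checkpoint.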
+def load_text_encoders(class_one, class_two, class_three):
+ text_encoder_one = class_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = class_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ text_encoder_three = class_three.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
+ )
+ return text_encoder_one, text_encoder_two, text_encoder_three
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+ autocast_ctx = nullcontext()
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+ if model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=77,
+ help="Maximum sequence length to use with with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd3-dreambooth",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="logit_normal",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--precondition_outputs",
+ type=int,
+ default=1,
+ help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
+ "model `target` is calculated.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
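+ # Preprocess every instance image once up front: resize, optional flip/crop, then ToTensor + Normalize
+ # (mean 0.5, std 0.5 maps pixel values from [0, 1] to [-1, 1]); the resulting tensors are cached in
+ # self.pixel_values so __getitem__ only has to index into them.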
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # no custom prompts were provided; fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
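+ # For example, with train_batch_size=2 the collated batch holds 4 images ordered
+ # [instance_0, instance_1, class_0, class_1], so the instance and class halves can later be
+ # split with a single chunk along the batch dimension.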
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+def _encode_prompt_with_t5(
+ text_encoder,
+ tokenizer,
+ max_sequence_length,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def _encode_prompt_with_clip(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ clip_tokenizers = tokenizers[:2]
+ clip_text_encoders = text_encoders[:2]
+
+ clip_prompt_embeds_list = []
+ clip_pooled_prompt_embeds_list = []
+ for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ prompt=prompt,
+ device=device if device is not None else text_encoder.device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ clip_prompt_embeds_list.append(prompt_embeds)
+ clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
+
+ clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
+
+ t5_prompt_embed = _encode_prompt_with_t5(
+ text_encoders[-1],
+ tokenizers[-1],
+ max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device if device is not None else text_encoders[-1].device,
+ )
+
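+ # The concatenated CLIP embeddings are zero-padded along the feature dimension up to the T5 hidden size,
+ # then stacked with the T5 embeddings along the sequence dimension so the transformer receives a single
+ # joint text-conditioning sequence.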
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+ tokenizer_two = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ )
+ tokenizer_three = T5TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_3",
+ revision=args.revision,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+ text_encoder_cls_three = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
+ text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = SD3Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ transformer.requires_grad_(True)
+ vae.requires_grad_(False)
+ if args.train_text_encoder:
+ text_encoder_one.requires_grad_(True)
+ text_encoder_two.requires_grad_(True)
+ text_encoder_three.requires_grad_(True)
+ else:
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ text_encoder_three.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=torch.float32)
+ if not args.train_text_encoder:
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_three.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+ text_encoder_three.gradient_checkpointing_enable()
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
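+ # (the two CLIP encoders are told apart by hidden size: 768 -> "text_encoder", 1280 -> "text_encoder_2";
+ # the T5 encoder is saved as "text_encoder_3")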
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for i, model in enumerate(models):
+ if isinstance(unwrap_model(model), SD3Transformer2DModel):
+ unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer"))
+ elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
+ if isinstance(unwrap_model(model), CLIPTextModelWithProjection):
+ hidden_size = unwrap_model(model).config.hidden_size
+ if hidden_size == 768:
+ unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder"))
+ elif hidden_size == 1280:
+ unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2"))
+ else:
+ unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_3"))
+ else:
+ raise ValueError(f"Wrong model supplied: {type(model)=}.")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ if isinstance(unwrap_model(model), SD3Transformer2DModel):
+ load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder="transformer")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):
+ try:
+ load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder")
+ model(**load_model.config)
+ model.load_state_dict(load_model.state_dict())
+ except Exception:
+ try:
+ load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder_2")
+ model(**load_model.config)
+ model.load_state_dict(load_model.state_dict())
+ except Exception:
+ try:
+ load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_3")
+ model(**load_model.config)
+ model.load_state_dict(load_model.state_dict())
+ except Exception:
+ raise ValueError(f"Couldn't load the model of type: ({type(model)}).")
+ else:
+ raise ValueError(f"Unsupported model found: {type(model)=}")
+
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
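+ # With --scale_lr the base learning rate is multiplied by the effective batch size; e.g. learning_rate=1e-4,
+ # gradient_accumulation_steps=2, train_batch_size=4 and 2 processes gives 1e-4 * 2 * 4 * 2 = 1.6e-3.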
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # use a different learning rate for the text encoders than for the transformer
+ text_parameters_one_with_lr = {
+ "params": text_encoder_one.parameters(),
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ text_parameters_two_with_lr = {
+ "params": text_encoder_two.parameters(),
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ text_parameters_three_with_lr = {
+ "params": text_encoder_three.parameters(),
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ transformer_parameters_with_lr,
+ text_parameters_one_with_lr,
+ text_parameters_two_with_lr,
+ text_parameters_three_with_lr,
+ ]
+ else:
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # set the learning rate of all three text-encoder parameter groups to --learning_rate,
+ # since Prodigy only uses a single initial learning rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+ params_to_optimize[2]["lr"] = args.learning_rate
+ params_to_optimize[3]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
+ text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders, tokenizers, prompt, args.max_sequence_length
+ )
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if not args.train_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders
+ # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
+ del text_encoder_one, text_encoder_two, text_encoder_three
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ if not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ pooled_prompt_embeds = instance_pooled_prompt_embeds
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
+ # if we're optimizing the text encoders (whether a single instance prompt is used for all images or custom
+ # prompts are provided), we need to tokenize and encode the batch prompts on every training step
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
+ tokens_three = tokenize_prompt(tokenizer_three, args.instance_prompt)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
+ class_tokens_three = tokenize_prompt(tokenizer_three, args.class_prompt)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+ tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
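+ # e.g. 100 dataloader batches with gradient_accumulation_steps=4 -> ceil(100 / 4) = 25 update steps per epoch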
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ (
+ transformer,
+ text_encoder_one,
+ text_encoder_two,
+ text_encoder_three,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ transformer,
+ text_encoder_one,
+ text_encoder_two,
+ text_encoder_three,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ )
+ else:
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-sd3"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
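+        # Look up the flow-matching sigma for each sampled timestep and add trailing
+        # singleton dimensions so it broadcasts over the latent tensor.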
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+ text_encoder_three.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ if args.train_text_encoder:
+ models_to_accumulate.extend([text_encoder_one, text_encoder_two, text_encoder_three])
+ with accelerator.accumulate(models_to_accumulate):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ prompts = batch["prompts"]
+
+                # Encode batch prompts when custom prompts are provided for each image.
+ if train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts)
+ tokens_two = tokenize_prompt(tokenizer_two, prompts)
+ tokens_three = tokenize_prompt(tokenizer_three, prompts)
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
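+                # Here `sigmas` plays the role of texp: sigma = 0 keeps the clean latents, sigma = 1 gives pure noise.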
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # Predict the noise residual
+ if not args.train_text_encoder:
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ timestep=timesteps,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ return_dict=False,
+ )[0]
+ else:
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[tokens_one, tokens_two, tokens_three],
+ )
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ timestep=timesteps,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ return_dict=False,
+ )[0]
+
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ # Preconditioning of the model outputs.
+ if args.precondition_outputs:
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
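+                # With `precondition_outputs`, the (preconditioned) model output predicts the clean latents x0,
+                # so the target is `model_input`; otherwise the model predicts the flow direction (noise - x0).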
+ if args.precondition_outputs:
+ target = model_input
+ else:
+ target = noise - model_input
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(
+ transformer.parameters(),
+ text_encoder_one.parameters(),
+ text_encoder_two.parameters(),
+ text_encoder_three.parameters(),
+ )
+ if args.train_text_encoder
+ else transformer.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ if not args.train_text_encoder:
+ text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
+ text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
+ )
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ text_encoder_3=accelerator.unwrap_model(text_encoder_three),
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+ if not args.train_text_encoder:
+ del text_encoder_one, text_encoder_two, text_encoder_three
+ torch.cuda.empty_cache()
+ gc.collect()
+
+    # Save the trained model
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_two = unwrap_model(text_encoder_two)
+ text_encoder_three = unwrap_model(text_encoder_three)
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=transformer,
+ text_encoder=text_encoder_one,
+ text_encoder_2=text_encoder_two,
+ text_encoder_3=text_encoder_three,
+ )
+ else:
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path, transformer=transformer
+ )
+
+ # save the pipeline
+ pipeline.save_pretrained(args.output_dir)
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.output_dir,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/inference/README.md b/diffusers/examples/inference/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..52d66be8e228d312f1d079e6c8123448b6fa86fd
--- /dev/null
+++ b/diffusers/examples/inference/README.md
@@ -0,0 +1,8 @@
+# Inference Examples
+
+**The inference examples folder is deprecated and will be removed in a future version**.
+**Officially supported inference examples can be found in the [Pipelines folder](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines)**.
+
+- For `Image-to-Image text-guided generation with Stable Diffusion`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)
+- For `In-painting using Stable Diffusion`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)
+- For `Tweak prompts reusing seeds and latents`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)
diff --git a/diffusers/examples/inference/image_to_image.py b/diffusers/examples/inference/image_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..86b46c4e606e039cb2ad80b341b2685694f883b4
--- /dev/null
+++ b/diffusers/examples/inference/image_to_image.py
@@ -0,0 +1,9 @@
+import warnings
+
+from diffusers import StableDiffusionImg2ImgPipeline # noqa F401
+
+
+warnings.warn(
+    "The `image_to_image.py` script is outdated. Please use `from diffusers import"
+    " StableDiffusionImg2ImgPipeline` directly instead."
+)
diff --git a/diffusers/examples/inference/inpainting.py b/diffusers/examples/inference/inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aad208ff34eb4d4ba1c6acfdfe0f97ac9afc4bc
--- /dev/null
+++ b/diffusers/examples/inference/inpainting.py
@@ -0,0 +1,9 @@
+import warnings
+
+from diffusers import StableDiffusionInpaintPipeline as StableDiffusionInpaintPipeline # noqa F401
+
+
+warnings.warn(
+    "The `inpainting.py` script is outdated. Please use `from diffusers import"
+    " StableDiffusionInpaintPipeline` directly instead."
+)
diff --git a/diffusers/examples/instruct_pix2pix/README.md b/diffusers/examples/instruct_pix2pix/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a5e91a5356257ad9e41c23e8acc49b097e2aee2e
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/README.md
@@ -0,0 +1,196 @@
+# InstructPix2Pix training example
+
+[InstructPix2Pix](https://arxiv.org/abs/2211.09800) is a method to fine-tune text-conditioned diffusion models so that they can follow an edit instruction for an input image. Models fine-tuned using this method take an input image and a textual edit instruction as inputs.
+
+The output is an "edited" image that reflects the edit instruction applied to the input image.
+
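+As a quick orientation, a fine-tuned checkpoint is used through `StableDiffusionInstructPix2PixPipeline`, roughly like this (a minimal sketch using the publicly released `timbrooks/instruct-pix2pix` checkpoint and the validation image referenced later in this README; the full inference example appears at the end of this document):
+
+```python
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
+).to("cuda")
+image = load_image("https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png")
+edited = pipe("make the mountains snowy", image=image).images[0]
+edited.save("snowy_mountain.png")
+```
+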
+The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+
+***Disclaimer: Even though `train_instruct_pix2pix.py` implements the InstructPix2Pix
+training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix), we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.***
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+### Toy example
+
+As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset
+is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.
+
+Configure environment variables such as the dataset identifier and the Stable Diffusion
+checkpoint:
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export DATASET_ID="fusing/instructpix2pix-1000-samples"
+```
+
+Now, we can launch training:
+
+```bash
+accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --seed=42 \
+ --push_to_hub
+```
+
+Additionally, we support performing validation inference to monitor training progress
+with Weights and Biases. You can enable this feature with `report_to="wandb"`:
+
+```bash
+accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
+ --validation_prompt="make the mountains snowy" \
+ --seed=42 \
+ --report_to=wandb \
+ --push_to_hub
+ ```
+
+ We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.
+
+ [Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters.
+
+ ***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
+
+ ## Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=stable-diffusion-v1-5/stable-diffusion-v1-5 \
+ --dataset_name=sayakpaul/instructpix2pix-1000-samples \
+ --use_ema \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --seed=42 \
+ --push_to_hub
+```
+
+ ## Inference
+
+ Once training is complete, we can perform inference:
+
+ ```python
+import PIL
+import requests
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+
+model_id = "your_model_id" # <- replace this
+pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+
+url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"
+
+
+def download_image(url):
+ image = PIL.Image.open(requests.get(url, stream=True).raw)
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+image = download_image(url)
+prompt = "wipe out the lake"
+num_inference_steps = 20
+image_guidance_scale = 1.5
+guidance_scale = 10
+
+edited_image = pipe(prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ image_guidance_scale=image_guidance_scale,
+ guidance_scale=guidance_scale,
+ generator=generator,
+).images[0]
+edited_image.save("edited_image.png")
+```
+
+An example model repo obtained using this training script can be found
+here - [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix).
+
+We encourage you to play with the following three parameters to control the
+speed and quality during inference:
+
+* `num_inference_steps`
+* `image_guidance_scale`
+* `guidance_scale`
+
+Particularly, `image_guidance_scale` and `guidance_scale` can have a profound impact
+on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example).
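+
+For instance, a quick sweep over `guidance_scale` (a rough sketch that simply reuses `pipe`, `image`, `prompt`, and the other variables from the inference snippet above) makes this trade-off easy to inspect:
+
+```python
+# Keep the image guidance fixed and vary only the text guidance scale.
+for gs in [3.0, 7.5, 10.0, 15.0]:
+    out = pipe(
+        prompt,
+        image=image,
+        num_inference_steps=num_inference_steps,
+        image_guidance_scale=image_guidance_scale,
+        guidance_scale=gs,
+        generator=generator,
+    ).images[0]
+    out.save(f"edited_gs_{gs}.png")
+```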
+
+If you're looking for some interesting ways to use the InstructPix2Pix training methodology, we welcome you to check out this blog post: [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd).
+
+## Stable Diffusion XL
+
+There's an equivalent `train_instruct_pix2pix_sdxl.py` script for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to the docs [here](./README_sdxl.md) to learn more.
diff --git a/diffusers/examples/instruct_pix2pix/README_sdxl.md b/diffusers/examples/instruct_pix2pix/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..dcf828e80ba8c163b444b1b4d7b819c1e25dacc1
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/README_sdxl.md
@@ -0,0 +1,197 @@
+# InstructPix2Pix SDXL training example
+
+***This is based on the original InstructPix2Pix training example.***
+
+[Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (or SDXL) is the latest image generation model that is tailored towards more photorealistic outputs with more detailed imagery and composition compared to previous SD models. It leverages a three times larger UNet backbone. The increase in model parameters is mainly due to more attention blocks and a larger cross-attention context, as SDXL uses a second text encoder.
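+
+To get a feel for the size difference, you can compare the UNet parameter counts directly (an illustrative sketch; the exact numbers depend on the checkpoints you load):
+
+```python
+from diffusers import UNet2DConditionModel
+
+for repo in ["stable-diffusion-v1-5/stable-diffusion-v1-5", "stabilityai/stable-diffusion-xl-base-1.0"]:
+    unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet")
+    n_params = sum(p.numel() for p in unet.parameters())
+    print(f"{repo}: {n_params / 1e6:.0f}M UNet parameters")
+```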
+
+The `train_instruct_pix2pix_sdxl.py` script shows how to implement the training procedure and adapt it for Stable Diffusion XL.
+
+***Disclaimer: Even though `train_instruct_pix2pix_sdxl.py` implements the InstructPix2Pix
+training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix), we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.***
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Refer to the original InstructPix2Pix training example for installing the dependencies.
+
+You will also need to get access to SDXL by filling out the [form](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
+
+### Toy example
+
+As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset
+is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.
+
+Configure environment variables such as the dataset identifier and the Stable Diffusion
+checkpoint:
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export DATASET_ID="fusing/instructpix2pix-1000-samples"
+```
+
+Now, we can launch training:
+
+```bash
+accelerate launch train_instruct_pix2pix_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --seed=42 \
+ --push_to_hub
+```
+
+Additionally, we support performing validation inference to monitor training progress
+with Weights and Biases. You can enable this feature with `report_to="wandb"`:
+
+```bash
+accelerate launch train_instruct_pix2pix_sdxl.py \
+ --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
+ --dataset_name=$DATASET_ID \
+ --use_ema \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --seed=42 \
+ --val_image_url_or_path="https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg" \
+ --validation_prompt="make it in japan" \
+ --report_to=wandb \
+ --push_to_hub
+ ```
+
+ We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.
+
+ [Here](https://wandb.ai/sayakpaul/instruct-pix2pix-sdxl-new/runs/sw53gxmc), you can find an example training run that includes some validation samples and the training hyperparameters.
+
+ ***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
+
+ ## Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix_sdxl.py \
+ --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
+ --dataset_name=$DATASET_ID \
+ --use_ema \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --seed=42 \
+ --val_image_url_or_path="https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg" \
+ --validation_prompt="make it in japan" \
+ --report_to=wandb \
+ --push_to_hub
+```
+
+ ## Inference
+
+ Once training is complete, we can perform inference:
+
+ ```python
+import PIL
+import requests
+import torch
+from diffusers import StableDiffusionXLInstructPix2PixPipeline
+
+model_id = "your_model_id" # <- replace this
+pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+
+url = "https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg"
+
+
+def download_image(url):
+ image = PIL.Image.open(requests.get(url, stream=True).raw)
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+image = download_image(url)
+prompt = "make it Japan"
+num_inference_steps = 20
+image_guidance_scale = 1.5
+guidance_scale = 10
+
+edited_image = pipe(prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ image_guidance_scale=image_guidance_scale,
+ guidance_scale=guidance_scale,
+ generator=generator,
+).images[0]
+edited_image.save("edited_image.png")
+```
+
+We encourage you to play with the following three parameters to control the
+speed and quality during inference:
+
+* `num_inference_steps`
+* `image_guidance_scale`
+* `guidance_scale`
+
+Particularly, `image_guidance_scale` and `guidance_scale` can have a profound impact
+on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example).
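+
+As in the SD example, a quick sweep (a rough sketch that simply reuses the variables from the inference snippet above) is an easy way to inspect this trade-off:
+
+```python
+# Keep the image guidance fixed and vary only the text guidance scale.
+for gs in [3.0, 7.5, 10.0, 15.0]:
+    out = pipe(
+        prompt,
+        image=image,
+        num_inference_steps=num_inference_steps,
+        image_guidance_scale=image_guidance_scale,
+        guidance_scale=gs,
+        generator=generator,
+    ).images[0]
+    out.save(f"edited_gs_{gs}.png")
+```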
+
+If you're looking for some interesting ways to use the InstructPix2Pix training methodology, we welcome you to check out this blog post: [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd).
+
+## Compare between SD and SDXL
+
+We aim to understand the differences resulting from the use of SD-1.5 and SDXL-0.9 as pretrained models. To achieve this, we trained on the [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) using both of these pretrained models. The training script is as follows:
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"  # or "stabilityai/stable-diffusion-xl-base-0.9"
+export DATASET_ID="fusing/instructpix2pix-1000-samples"
+
+accelerate launch train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --use_ema \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --seed=42 \
+ --val_image_url="https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg" \
+ --validation_prompt="make it in Japan" \
+ --report_to=wandb \
+ --push_to_hub
+```
+
+We discovered that, compared to training with SD-1.5 as the pretrained model, SDXL-0.9 results in a lower training loss value (SD-1.5: 0.0599, SDXL: 0.0254). Moreover, from a visual perspective, the results obtained using SDXL demonstrated fewer artifacts and richer detail. Notably, SDXL starts to preserve the structure of the original image earlier on.
+
+The following two GIFs provide intuitive visual results. For each checkpointing step, we inspected the images generated for the validation image with "make it in Japan" as the prompt. It can be seen that SDXL starts preserving the details of the original image earlier, resulting in higher-fidelity outcomes sooner.
+
+* SD-1.5: https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd_ip2p_training_val_img_progress.gif
+
+* SDXL: https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_ip2p_training_val_img_progress.gif
+
diff --git a/diffusers/examples/instruct_pix2pix/requirements.txt b/diffusers/examples/instruct_pix2pix/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e18cc9e4215eaa760c8d29c946396dba9ff2c9ac
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+ftfy
+tensorboard
\ No newline at end of file
diff --git a/diffusers/examples/instruct_pix2pix/test_instruct_pix2pix.py b/diffusers/examples/instruct_pix2pix/test_instruct_pix2pix.py
new file mode 100644
index 0000000000000000000000000000000000000000..f72f95db6afd46716d8a6e1c910879ca354e042c
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/test_instruct_pix2pix.py
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class InstructPix2Pix(ExamplesTestsAccelerate):
+ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=6
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ resume_run_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8"},
+ )
diff --git a/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix.py b/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5b7eaac4a1fadbec03d84822121f6fcfb6f4646
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -0,0 +1,1031 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
+
+import argparse
+import logging
+import math
+import os
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import PIL
+import requests
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"),
+}
+WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ original_image = download_image(args.val_image_url)
+ edited_images = []
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ for _ in range(args.num_validation_images):
+ edited_images.append(
+ pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+ tracker.log({"validation": wandb_table})
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--original_image_column",
+ type=str,
+ default="input_image",
+        help="The column of the dataset containing the original image on which edits were made.",
+ )
+ parser.add_argument(
+ "--edited_image_column",
+ type=str,
+ default="edited_image",
+ help="The column of the dataset containing the edited image.",
+ )
+ parser.add_argument(
+ "--edit_prompt_column",
+ type=str,
+ default="edit_prompt",
+ help="The column of the dataset containing the edit instruction.",
+ )
+ parser.add_argument(
+ "--val_image_url",
+ type=str,
+ default=None,
+ help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="instruct-pix2pix-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=256,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--conditioning_dropout_prob",
+ type=float,
+ default=None,
+ help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def convert_to_np(image, resolution):
+ image = image.convert("RGB").resize((resolution, resolution))
+ return np.array(image).transpose(2, 0, 1)
+
+
+def download_image(url):
+ image = PIL.Image.open(requests.get(url, stream=True).raw)
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+    generator = torch.Generator(device=accelerator.device)
+    if args.seed is not None:
+        generator.manual_seed(args.seed)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
+ # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
+ # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
+    # from the pre-trained checkpoints. The extra channels added to the first (conv) layer are
+    # initialized to zero.
+ logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+ in_channels = 8
+ out_channels = unet.conv_in.out_channels
+ unet.register_to_config(in_channels=in_channels)
+
+ with torch.no_grad():
+ new_conv_in = nn.Conv2d(
+ in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+ )
+ new_conv_in.weight.zero_()
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+ unet.conv_in = new_conv_in
+
+ # Freeze vae and text_encoder
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.original_image_column is None:
+ original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ original_image_column = args.original_image_column
+ if original_image_column not in column_names:
+ raise ValueError(
+ f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edit_prompt_column is None:
+ edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ edit_prompt_column = args.edit_prompt_column
+ if edit_prompt_column not in column_names:
+ raise ValueError(
+ f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edited_image_column is None:
+ edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
+ else:
+ edited_image_column = args.edited_image_column
+ if edited_image_column not in column_names:
+ raise ValueError(
+ f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(captions):
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ ]
+ )
+
+ def preprocess_images(examples):
+ original_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
+ )
+ edited_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
+ )
+ # We need to ensure that the original and the edited images undergo the same
+ # augmentation transforms.
+ images = np.concatenate([original_images, edited_images])
+ images = torch.tensor(images)
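+        # Map uint8 pixel values from [0, 255] to [-1, 1], the input range expected by the Stable Diffusion VAE.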
+ images = 2 * (images / 255) - 1
+ return train_transforms(images)
+
+ def preprocess_train(examples):
+ # Preprocess images.
+ preprocessed_images = preprocess_images(examples)
+ # Since the original and edited images were concatenated before
+ # applying the transformations, we need to separate them and reshape
+ # them accordingly.
+ original_images, edited_images = preprocessed_images.chunk(2)
+ original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
+ edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
+
+ # Collate the preprocessed images into the `examples`.
+ examples["original_pixel_values"] = original_images
+ examples["edited_pixel_values"] = edited_images
+
+ # Preprocess the captions.
+ captions = list(examples[edit_prompt_column])
+ examples["input_ids"] = tokenize_captions(captions)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
+ original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
+ edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
+ edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {
+ "original_pixel_values": original_pixel_values,
+ "edited_pixel_values": edited_pixel_values,
+ "input_ids": input_ids,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision,
+    # as these models are only used for inference; keeping their weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+    # Move the text_encoder and vae to the GPU and cast them to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("instruct-pix2pix", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
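+            # `global_step` counts optimizer updates; multiplying by `gradient_accumulation_steps`
+            # recovers the number of dataloader batches already consumed, which is used below to
+            # skip already-seen batches within the resumed epoch.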
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # We want to learn the denoising process w.r.t the edited images which
+ # are conditioned on the original image (which was edited) and the edit instruction.
+ # So, first, convert images to latent space.
+ latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
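+                # `add_noise` applies the closed-form forward process:
+                # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.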
+
+ # Get the text embedding for conditioning.
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the additional image embedding for conditioning.
+ # Instead of getting a diagonal Gaussian here, we simply take the mode.
+ original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()
+
+ # Conditioning dropout to support classifier-free guidance during inference. For more details
+ # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
+ if args.conditioning_dropout_prob is not None:
+ random_p = torch.rand(bsz, device=latents.device, generator=generator)
+ # Sample masks for the edit prompts.
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
+ # Final text conditioning.
+ null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0]
+ encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
+
+ # Sample masks for the original images.
+ image_mask_dtype = original_image_embeds.dtype
+ image_mask = 1 - (
+ (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
+ )
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
+ # Final image conditioning.
+ original_image_embeds = image_mask * original_image_embeds
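+                    # The masks above partition the batch by `random_p` (with p = conditioning_dropout_prob):
+                    #   [0, p)   -> drop only the text conditioning
+                    #   [p, 2p)  -> drop both text and image conditioning
+                    #   [2p, 3p) -> drop only the image conditioning
+                    #   [3p, 1)  -> keep both
+                    # so each conditioning is dropped with probability 2p, and both together with probability p.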
+
+ # Concatenate the `original_image_embeds` with the `noisy_latents`.
+ concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
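+                # The UNet's first conv layer takes 8 input channels: 4 noisy latent channels plus
+                # 4 channels from the encoded original image used as conditioning.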
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if (
+ (args.val_image_url is not None)
+ and (args.validation_prompt is not None)
+ and (epoch % args.validation_epochs == 0)
+ ):
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+                # The models need unwrapping for compatibility in distributed training mode.
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ text_encoder=unwrap_model(text_encoder),
+ vae=unwrap_model(vae),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ )
+
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=unwrap_model(text_encoder),
+ vae=unwrap_model(vae),
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ if (args.val_image_url is not None) and (args.validation_prompt is not None):
+ log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ )
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..c88be6d16d88d57276f149522741afaf65e4c026
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -0,0 +1,1261 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 Harutatsu Akiyama and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+from urllib.parse import urlparse
+
+import accelerate
+import datasets
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import (
+ StableDiffusionXLInstructPix2PixPipeline,
+)
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"),
+}
+WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
+TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
+
+
+def log_validation(pipeline, args, accelerator, generator, global_step, is_final_validation=False):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ val_save_dir = os.path.join(args.output_dir, "validation_images")
+ if not os.path.exists(val_save_dir):
+ os.makedirs(val_save_dir)
+
+    if urlparse(args.val_image_url_or_path).scheme:
+        original_image = load_image(args.val_image_url_or_path)
+    else:
+        original_image = Image.open(args.val_image_url_or_path).convert("RGB")
+
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ edited_images = []
+ # Run inference
+ for val_img_idx in range(args.num_validation_images):
+ a_val_img = pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ edited_images.append(a_val_img)
+ # Save validation images
+ a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+ logger_name = "test" if is_final_validation else "validation"
+ tracker.log({logger_name: wandb_table})
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Script to train Stable Diffusion XL for InstructPix2Pix.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--vae_precision",
+ type=str,
+ choices=["fp32", "fp16", "bf16"],
+ default="fp32",
+ help=(
+ "The vanilla SDXL 1.0 VAE can cause NaNs due to large activation values. Some custom models might already have a solution"
+ " to this problem, and this flag allows you to use mixed precision to stabilize training."
+ ),
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--original_image_column",
+ type=str,
+ default="input_image",
+ help="The column of the dataset containing the original image on which edits where made.",
+ )
+ parser.add_argument(
+ "--edited_image_column",
+ type=str,
+ default="edited_image",
+ help="The column of the dataset containing the edited image.",
+ )
+ parser.add_argument(
+ "--edit_prompt_column",
+ type=str,
+ default="edit_prompt",
+ help="The column of the dataset containing the edit instruction.",
+ )
+ parser.add_argument(
+ "--val_image_url_or_path",
+ type=str,
+ default=None,
+ help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run fine-tuning validation every X steps. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="instruct-pix2pix-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=256,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this resolution."
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--conditioning_dropout_prob",
+ type=float,
+ default=None,
+ help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
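+# A hypothetical launch command for this script (model name, dataset and hyperparameters are
+# illustrative only and should be adapted to your setup):
+#
+#   accelerate launch train_instruct_pix2pix_sdxl.py \
+#     --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
+#     --dataset_name=fusing/instructpix2pix-1000-samples \
+#     --resolution=256 --train_batch_size=4 --gradient_accumulation_steps=4 \
+#     --learning_rate=5e-5 --max_train_steps=1000 --conditioning_dropout_prob=0.05 \
+#     --checkpointing_steps=500 --output_dir=instruct-pix2pix-xl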
+
+def convert_to_np(image, resolution):
+ if isinstance(image, str):
+ image = PIL.Image.open(image)
+ image = image.convert("RGB").resize((resolution, resolution))
+ return np.array(image).transpose(2, 0, 1)
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+    generator = torch.Generator(device=accelerator.device)
+    if args.seed is not None:
+        generator.manual_seed(args.seed)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
+ # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
+ # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
+    # from the pre-trained checkpoints. The extra channels added to the first layer are
+    # initialized to zero.
+ logger.info("Initializing the XL InstructPix2Pix UNet from the pretrained UNet.")
+ in_channels = 8
+ out_channels = unet.conv_in.out_channels
+ unet.register_to_config(in_channels=in_channels)
+
+ with torch.no_grad():
+ new_conv_in = nn.Conv2d(
+ in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+ )
+ new_conv_in.weight.zero_()
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+ unet.conv_in = new_conv_in
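+        # The copied weights fill the first 4 input channels (noisy latents); the remaining 4
+        # channels (the encoded original image) start at zero, so at initialization the image
+        # conditioning contributes nothing to the first conv layer's output.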
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                    # make sure to pop the weights so that the corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.original_image_column is None:
+ original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ original_image_column = args.original_image_column
+ if original_image_column not in column_names:
+ raise ValueError(
+ f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edit_prompt_column is None:
+ edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ edit_prompt_column = args.edit_prompt_column
+ if edit_prompt_column not in column_names:
+ raise ValueError(
+ f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edited_image_column is None:
+ edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
+ else:
+ edited_image_column = args.edited_image_column
+ if edited_image_column not in column_names:
+ raise ValueError(
+ f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision,
+    # as these models are only used for inference; keeping their weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning)
+
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning)
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(captions, tokenizer):
+ inputs = tokenizer(
+ captions,
+ max_length=tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ ]
+ )
+
+ def preprocess_images(examples):
+ original_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
+ )
+ edited_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
+ )
+ # We need to ensure that the original and the edited images undergo the same
+ # augmentation transforms.
+ images = np.concatenate([original_images, edited_images])
+ images = torch.tensor(images)
+ images = 2 * (images / 255) - 1
+ return train_transforms(images)
+
+ # Load scheduler, tokenizer and models.
+ tokenizer_1 = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_2 = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+ text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+ text_encoder_cls_2 = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_1 = text_encoder_cls_1.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_2 = text_encoder_cls_2.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+
+    # We always pre-compute the additional condition embeddings needed for the SDXL
+    # UNet, as the model is already big and it uses two text encoders.
+ text_encoder_1.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
+ tokenizers = [tokenizer_1, tokenizer_2]
+ text_encoders = [text_encoder_1, text_encoder_2]
+
+ # Freeze vae and text_encoders
+ vae.requires_grad_(False)
+ text_encoder_1.requires_grad_(False)
+ text_encoder_2.requires_grad_(False)
+
+ # Set UNet to trainable.
+ unet.train()
+
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(text_encoders, tokenizers, prompt):
+ prompt_embeds_list = []
+
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+            # We are always interested only in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
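+    # In `encode_prompt` above, the two encoders' penultimate hidden states are concatenated along
+    # the feature dimension (768 + 1280 = 2048 for the standard SDXL text encoders), while only the
+    # second encoder's pooled output is kept and later used as the `add_text_embeds` conditioning.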
+
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompts(text_encoders, tokenizers, prompts):
+ prompt_embeds_all = []
+ pooled_prompt_embeds_all = []
+
+ for prompt in prompts:
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+ prompt_embeds_all.append(prompt_embeds)
+ pooled_prompt_embeds_all.append(pooled_prompt_embeds)
+
+ return torch.stack(prompt_embeds_all), torch.stack(pooled_prompt_embeds_all)
+
+ # Adapted from examples.dreambooth.train_dreambooth_lora_sdxl
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings_for_prompts(prompts, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds_all, pooled_prompt_embeds_all = encode_prompts(text_encoders, tokenizers, prompts)
+ add_text_embeds_all = pooled_prompt_embeds_all
+
+ prompt_embeds_all = prompt_embeds_all.to(accelerator.device)
+ add_text_embeds_all = add_text_embeds_all.to(accelerator.device)
+ return prompt_embeds_all, add_text_embeds_all
+
+ # Get null conditioning
+ def compute_null_conditioning():
+ null_conditioning_list = []
+ for a_tokenizer, a_text_encoder in zip(tokenizers, text_encoders):
+ null_conditioning_list.append(
+ a_text_encoder(
+ tokenize_captions([""], tokenizer=a_tokenizer).to(accelerator.device),
+ output_hidden_states=True,
+ ).hidden_states[-2]
+ )
+ return torch.concat(null_conditioning_list, dim=-1)
+
+ null_conditioning = compute_null_conditioning()
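+    # This empty-prompt embedding stands in for the text conditioning whenever conditioning
+    # dropout is applied during training, which is what enables classifier-free guidance at inference.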
+
+ def compute_time_ids():
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ original_size = target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids], dtype=weight_dtype)
+ return add_time_ids.to(accelerator.device).repeat(args.train_batch_size, 1)
+
+ add_time_ids = compute_time_ids()
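+    # SDXL micro-conditioning: the UNet additionally receives (original_size, crop_top_left, target_size)
+    # as "time ids"; here they are fixed to the training resolution and the configured crop offsets
+    # and repeated for every sample in the batch.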
+
+ def preprocess_train(examples):
+ # Preprocess images.
+ preprocessed_images = preprocess_images(examples)
+ # Since the original and edited images were concatenated before
+ # applying the transformations, we need to separate them and reshape
+ # them accordingly.
+ original_images, edited_images = preprocessed_images.chunk(2)
+ original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
+ edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
+
+ # Collate the preprocessed images into the `examples`.
+ examples["original_pixel_values"] = original_images
+ examples["edited_pixel_values"] = edited_images
+
+ # Preprocess the captions.
+ captions = list(examples[edit_prompt_column])
+ prompt_embeds_all, add_text_embeds_all = compute_embeddings_for_prompts(captions, text_encoders, tokenizers)
+ examples["prompt_embeds"] = prompt_embeds_all
+ examples["add_text_embeds"] = add_text_embeds_all
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
+ original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
+ edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
+ edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
+ prompt_embeds = torch.concat([example["prompt_embeds"] for example in examples], dim=0)
+ add_text_embeds = torch.concat([example["add_text_embeds"] for example in examples], dim=0)
+ return {
+ "original_pixel_values": original_pixel_values,
+ "edited_pixel_values": edited_pixel_values,
+ "prompt_embeds": prompt_embeds,
+ "add_text_embeds": add_text_embeds,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+    # Move the VAE to the device. If a custom VAE is supplied it is cast to weight_dtype;
+    # otherwise the default SDXL VAE runs in `args.vae_precision` (fp32 by default) to avoid NaN losses.
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ else:
+ vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[args.vae_precision])
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # We want to learn the denoising process w.r.t the edited images which
+ # are conditioned on the original image (which was edited) and the edit instruction.
+ # So, first, convert images to latent space.
+ if args.pretrained_vae_model_name_or_path is not None:
+ edited_pixel_values = batch["edited_pixel_values"].to(dtype=weight_dtype)
+ else:
+ edited_pixel_values = batch["edited_pixel_values"]
+ latents = vae.encode(edited_pixel_values).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # SDXL additional inputs
+ encoder_hidden_states = batch["prompt_embeds"]
+ add_text_embeds = batch["add_text_embeds"]
+
+                # Get the additional image embedding for conditioning by encoding the original image
+                # with the VAE and sampling from the resulting latent distribution.
+ if args.pretrained_vae_model_name_or_path is not None:
+ original_pixel_values = batch["original_pixel_values"].to(dtype=weight_dtype)
+ else:
+ original_pixel_values = batch["original_pixel_values"]
+ original_image_embeds = vae.encode(original_pixel_values).latent_dist.sample()
+ if args.pretrained_vae_model_name_or_path is None:
+ original_image_embeds = original_image_embeds.to(weight_dtype)
+
+ # Conditioning dropout to support classifier-free guidance during inference. For more details
+ # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
+ if args.conditioning_dropout_prob is not None:
+ random_p = torch.rand(bsz, device=latents.device, generator=generator)
+ # Sample masks for the edit prompts.
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
+ # Final text conditioning.
+ encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
+
+ # Sample masks for the original images.
+ image_mask_dtype = original_image_embeds.dtype
+ image_mask = 1 - (
+ (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
+ )
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
+ # Final image conditioning.
+ original_image_embeds = image_mask * original_image_embeds
+
+ # Concatenate the `original_image_embeds` with the `noisy_latents`.
+ concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ model_pred = unet(
+ concatenated_noisy_latents,
+ timesteps,
+ encoder_hidden_states,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ ### BEGIN: Perform validation every `validation_steps` steps
+ if global_step % args.validation_steps == 0:
+ if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
+ # create pipeline
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ # The models need to be unwrapped for compatibility in distributed training mode.
+ pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ text_encoder=text_encoder_1,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer_1,
+ tokenizer_2=tokenizer_2,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ global_step,
+ is_final_validation=False,
+ )
+
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ del pipeline
+ torch.cuda.empty_cache()
+ ### END: Perform validation every `validation_steps` steps
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder_1,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer_1,
+ tokenizer_2=tokenizer_2,
+ vae=vae,
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
+ log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ global_step,
+ is_final_validation=True,
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/README.md b/diffusers/examples/kandinsky2_2/text_to_image/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6daf06dcb8c076173d576a0ed7a27effe794a61b
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/README.md
@@ -0,0 +1,317 @@
+# Kandinsky2.2 text-to-image fine-tuning
+
+Kandinsky 2.2 includes a prior pipeline that generates image embeddings from text prompts, and a decoder pipeline that generates the output image based on the image embeddings. We provide `train_text_to_image_prior.py` and `train_text_to_image_decoder.py` scripts to show you how to fine-tune the Kandinsky prior and decoder models separately based on your own dataset. To achieve the best results, you should fine-tune **_both_** your prior and decoder models.
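+
+For reference, here is a minimal sketch of how the two stages fit together at inference time (standard two-stage Kandinsky 2.2 usage; the checkpoints shown are the public base models, not the fine-tuned ones produced below):
+
+```python
+from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
+import torch
+
+# Stage 1: the prior turns a text prompt into CLIP image embeddings.
+pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
+    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+).to("cuda")
+prior_out = pipe_prior("A robot naruto, 4k photo")
+
+# Stage 2: the decoder turns the image embeddings into an image.
+pipe_decoder = KandinskyV22Pipeline.from_pretrained(
+    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+).to("cuda")
+image = pipe_decoder(
+    image_embeds=prior_out.image_embeds,
+    negative_image_embeds=prior_out.negative_image_embeds,
+    height=768,
+    width=768,
+).images[0]
+```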
+
+___Note___:
+
+___This script is experimental. The script fine-tunes the whole model, and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
+
+
+## Running locally with PyTorch
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+For this example we want to directly store the trained weights on the Hub, so we need to be logged in and add the `--push_to_hub` flag.
+
+___
+
+### Naruto example
+
+For all our examples, we will directly store the trained weights on the Hub, so we need to be logged in and add the `--push_to_hub` flag. In order to do that, you have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to the [User Access Tokens](https://huggingface.co/docs/hub/security-tokens) guide.
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+We also use [Weights and Biases](https://docs.wandb.ai/quickstart) logging by default, because it makes it easy to monitor training progress by regularly generating sample images during training. To install wandb, run
+
+```bash
+pip install wandb
+```
+
+To disable wandb logging, remove the `--report_to="wandb"` and `--validation_prompts="A robot naruto, 4k photo"` flags from the examples below
+
+#### Fine-tune decoder
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot naruto, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-decoder-naruto-model"
+```
+
+
+
+To train on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in the [ImageFolder with metadata](https://huggingface.co/docs/datasets/en/image_load#imagefolder-with-metadata) guide.
+If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script. A quick way to sanity-check your dataset folder is shown after the command below.
+
+```bash
+export TRAIN_DIR="path_to_your_dataset"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
+ --train_data_dir=$TRAIN_DIR \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot naruto, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi22-decoder-naruto-model"
+```
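+
+Before launching training, you can quickly verify that 🤗 Datasets can read your folder (a minimal check, assuming your `path_to_your_dataset` folder follows the ImageFolder-with-metadata layout, i.e. images plus a `metadata.jsonl`):
+
+```python
+from datasets import load_dataset
+
+# "path_to_your_dataset" is the same folder you pass via --train_data_dir
+dataset = load_dataset("imagefolder", data_dir="path_to_your_dataset", split="train")
+print(dataset[0])  # should show an "image" field plus your metadata columns
+```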
+
+
+Once training is finished, the model will be saved in the `output_dir` specified in the command; in this example, that is `kandi22-decoder-naruto-model`. To load the fine-tuned model for inference, just pass that path to `AutoPipelineForText2Image`:
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+output_dir = "kandi22-decoder-naruto-model"  # the --output_dir used during training
+pipe = AutoPipelineForText2Image.from_pretrained(output_dir, torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+
+prompt='A robot naruto, 4k photo'
+images = pipe(prompt=prompt).images
+images[0].save("robot-naruto.png")
+```
+
+Checkpoints only save the `unet`, so to run inference from a checkpoint, just load the `unet`:
+```python
+from diffusers import AutoPipelineForText2Image, UNet2DConditionModel
+import torch
+
+model_path = "path_to_saved_model"
+
+# Replace <N> with the step number of the checkpoint you want to load, e.g. "checkpoint-15000".
+unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet", torch_dtype=torch.float16)
+
+pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet, torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+
+image = pipe(prompt="A robot naruto, 4k photo").images[0]
+image.save("robot-naruto.png")
+```
+
+#### Fine-tune prior
+
+You can fine-tune the Kandinsky prior model with `train_text_to_image_prior.py` script. Note that we currently do not support `--gradient_checkpointing` for prior model fine-tuning.
+
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot naruto, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-prior-naruto-model"
+```
+
+
+
+To perform inference with the fine-tuned prior model, you will need to first create a prior pipeline by passing the `output_dir` to `DiffusionPipeline`. Then create a `KandinskyV22CombinedPipeline` from a pretrained or fine-tuned decoder checkpoint along with all the modules of the prior pipeline you just created.
+
+```python
+from diffusers import AutoPipelineForText2Image, DiffusionPipeline
+import torch
+
+output_dir = "kandi2-prior-naruto-model"  # the --output_dir used when fine-tuning the prior
+pipe_prior = DiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16)
+prior_components = {"prior_" + k: v for k,v in pipe_prior.components.items()}
+pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components, torch_dtype=torch.float16)
+
+pipe.enable_model_cpu_offload()
+prompt='A robot naruto, 4k photo'
+negative_prompt='low quality, bad quality'  # any negative prompt works here; it can also be omitted
+images = pipe(prompt=prompt, negative_prompt=negative_prompt).images
+images[0]
+```
+
+If you want to use a fine-tuned decoder checkpoint along with your fine-tuned prior checkpoint, you can simply replace the "kandinsky-community/kandinsky-2-2-decoder" in the above code with your custom model repo name. Note that in order to create a `KandinskyV22CombinedPipeline`, your model repository needs to have a prior tag. If you have created your model repo using our training script, the prior tag is automatically included.
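+
+For example, assuming you pushed the prior fine-tune to `your-username/kandi2-prior-naruto-model` and the decoder fine-tune to `your-username/kandi22-decoder-naruto-model` (hypothetical repo names -- substitute your own), the combined pipeline can be assembled like this:
+
+```python
+from diffusers import AutoPipelineForText2Image, DiffusionPipeline
+import torch
+
+# Load the fine-tuned prior and collect its components.
+pipe_prior = DiffusionPipeline.from_pretrained("your-username/kandi2-prior-naruto-model", torch_dtype=torch.float16)
+prior_components = {"prior_" + k: v for k, v in pipe_prior.components.items()}
+
+# Combine them with the fine-tuned decoder.
+pipe = AutoPipelineForText2Image.from_pretrained(
+    "your-username/kandi22-decoder-naruto-model", **prior_components, torch_dtype=torch.float16
+)
+pipe.enable_model_cpu_offload()
+
+image = pipe(prompt="A robot naruto, 4k photo").images[0]
+image.save("robot-naruto.png")
+```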
+
+#### Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image_decoder.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot naruto, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-decoder-naruto-model"
+```
+
+
+#### Training with Min-SNR weighting
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556), which helps achieve faster convergence by rebalancing the loss.
+Pass the `--snr_gamma` argument and set it to the recommended value of 5.0, as shown below.
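+
+For example, to add it to the decoder fine-tuning command from above (all other flags unchanged and omitted here for brevity):
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
+  --dataset_name=$DATASET_NAME \
+  --resolution=768 \
+  --snr_gamma=5.0 \
+  --output_dir="kandi2-decoder-naruto-model"
+```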
+
+
+## Training with LoRA
+
+Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means the trained LoRA weights are easily portable.
+- LoRA attention layers allow controlling to what extent the model is adapted toward new training images via a `scale` parameter.
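+
+For intuition, the update learned by LoRA can be sketched as a tiny PyTorch module (illustrative only -- the training scripts rely on diffusers' built-in LoRA attention processors rather than this toy wrapper):
+
+```python
+import torch.nn as nn
+
+class LoRALinear(nn.Module):
+    """Toy LoRA wrapper: y = W x + scale * B(A(x)), where only A and B are trained."""
+
+    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
+        super().__init__()
+        self.base = base
+        for p in self.base.parameters():
+            p.requires_grad_(False)  # pretrained weights stay frozen
+        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
+        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
+        nn.init.zeros_(self.lora_b.weight)  # start as a no-op so training begins from the base model
+        self.scale = scale
+
+    def forward(self, x):
+        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))
+```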
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+With LoRA, it's possible to fine-tune Kandinsky 2.2 on a custom image-caption pair dataset
+on consumer GPUs like the Tesla T4 or Tesla V100.
+
+### Training
+
+First, you need to set up your development environment as explained in the [Running locally with PyTorch](#running-locally-with-pytorch) section above. Make sure to set the `DATASET_NAME` environment variable. Here, we will use [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) and the [Naruto dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).
+
+
+#### Train decoder
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora_decoder.py \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --rank=4 \
+ --gradient_checkpointing \
+ --output_dir="kandi22-decoder-naruto-lora" \
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
+ --push_to_hub
+```
+
+#### Train prior
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora_prior.py \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --rank=4 \
+ --output_dir="kandi22-prior-naruto-lora" \
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
+ --push_to_hub
+```
+
+**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run the above scripts on consumer GPUs like T4 or V100.___**
+
+
+### Inference
+
+#### Inference using fine-tuned LoRA checkpoint for decoder
+
+Once you have trained a Kandinsky decoder model using the above command, inference can be done with `AutoPipelineForText2Image` after loading the trained LoRA weights. You need to pass the `output_dir` used during training to load the LoRA weights, which in this case is `kandi22-decoder-naruto-lora`.
+
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
+output_dir = "kandi22-decoder-naruto-lora"  # the --output_dir used during LoRA training
+pipe.unet.load_attn_procs(output_dir)
+pipe.enable_model_cpu_offload()
+
+prompt='A robot naruto, 4k photo'
+image = pipe(prompt=prompt).images[0]
+image.save("robot_naruto.png")
+```
+
+#### Inference using fine-tuned LoRA checkpoint for prior
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
+output_dir = "kandi22-prior-naruto-lora"  # the --output_dir used during prior LoRA training
+pipe.prior_prior.load_attn_procs(output_dir)
+pipe.enable_model_cpu_offload()
+
+prompt='A robot naruto, 4k photo'
+image = pipe(prompt=prompt).images[0]
+image.save("robot_naruto.png")
+image
+```
+
+### Training with xFormers
+
+You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.
+
+xFormers training is not available for fine-tuning the prior model.
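+
+For example, appended to the decoder fine-tuning command (all other flags omitted here for brevity):
+
+```bash
+accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
+  --dataset_name=$DATASET_NAME \
+  --enable_xformers_memory_efficient_attention \
+  --output_dir="kandi2-decoder-naruto-model"
+```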
+
+**Note**:
+
+According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training in some GPUs. If you observe that problem, please install a development version as indicated in that comment.
\ No newline at end of file
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/requirements.txt b/diffusers/examples/kandinsky2_2/text_to_image/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..31b9026efdc2799b1d02e2e3f4d8dfc463737fdc
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9caa3694d636733cc179e1b4372198e4fa415988
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -0,0 +1,929 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {args.pretrained_decoder_model_name_or_path}
+datasets:
+- {args.dataset_name}
+prior:
+- {args.pretrained_prior_model_name_or_path}
+tags:
+- kandinsky
+- text-to-image
+- diffusers
+- diffusers-training
+inference: true
+---
+ """
+ model_card = f"""
+# Finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_decoder_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(vae, image_encoder, image_processor, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ prior_image_encoder=accelerator.unwrap_model(image_encoder),
+ prior_image_processor=image_processor,
+ unet=accelerator.unwrap_model(unet),
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-decoder",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="kandi_2_2-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+ help="Weight decay to use.",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="scheduler")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ returns either a context list that includes one that will disable zero.Init or an empty context list
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ vae = VQModel.from_pretrained(
+ args.pretrained_decoder_model_name_or_path, subfolder="movq", torch_dtype=weight_dtype
+ ).eval()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
+ ).eval()
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
+
+ # Freeze vae and image_encoder
+ vae.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ # Set unet to trainable.
+ unet.train()
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+ ema_unet.to(accelerator.device)
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")
+
+ def center_crop(image):
+ width, height = image.size
+ new_size = min(width, height)
+ left = (width - new_size) / 2
+ top = (height - new_size) / 2
+ right = (width + new_size) / 2
+ bottom = (height + new_size) / 2
+ return image.crop((left, top, right, bottom))
+
+ def train_transforms(img):
+ img = center_crop(img)
+ img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)
+ img = np.array(img).astype(np.float32) / 127.5 - 1
+ img = torch.from_numpy(np.transpose(img, [2, 0, 1]))
+ return img
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
+ return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values}
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+ # Move image_encoder and vae to GPU and cast to weight_dtype
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ images = batch["pixel_values"].to(weight_dtype)
+ clip_images = batch["clip_pixel_values"].to(weight_dtype)
+ latents = vae.encode(images).latents
+ image_embeds = image_encoder(clip_images).image_embeds
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ target = noise
+
+ # Predict the noise residual and compute loss
+ added_cond_kwargs = {"image_embeds": image_embeds}
+
+ model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ image_encoder,
+ image_processor,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ vae=vae,
+ unet=unet,
+ )
+ pipeline.decoder_pipe.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+ pipeline.enable_model_cpu_offload()
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..23f5d342b396b81146ee6a8f1ada3829047abf0c
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -0,0 +1,812 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Kandinsky with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from PIL import Image
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+import diffusers
+from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnAddedKVProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(repo_id: str, images=None, base_model=None, dataset_name=None, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- kandinsky
+- text-to-image
+- diffusers
+- diffusers-training
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2 with LoRA.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-decoder",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+            " `args.validation_prompt` `args.num_validation_images` times."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="kandi_2_2-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+            "The resolution for input images; all images in the train/validation dataset will be resized to this"
+            " resolution."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=0.0, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="scheduler")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+ )
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder"
+ )
+
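+    # Kandinsky 2.2 uses a MoVQ (VQ-GAN style) autoencoder, stored in the "movq" subfolder, instead of a KL VAE.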
+ vae = VQModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="movq")
+
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ image_encoder.requires_grad_(False)
+
+    # For mixed precision training we cast all non-trainable weights (vae, image_encoder and non-LoRA unet) to half-precision,
+    # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+    # Move unet, vae and image_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+
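+    # Attach a LoRA processor to every attention layer of the UNet. Self-attention layers
+    # ("attn1") have no cross-attention dimension; the hidden size is inferred from the
+    # down/mid/up block the processor belongs to.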
+ lora_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRAAttnAddedKVProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ rank=args.rank,
+ )
+
+ unet.set_attn_processor(lora_attn_procs)
+
+ lora_layers = AttnProcsLayers(unet.attn_processors)
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+    # Preprocessing the datasets.
+    # We need to normalize the images and compute their CLIP pixel values.
+ column_names = dataset["train"].column_names
+
+ image_column = args.image_column
+ if image_column not in column_names:
+        raise ValueError(f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")
+
+ def center_crop(image):
+ width, height = image.size
+ new_size = min(width, height)
+ left = (width - new_size) / 2
+ top = (height - new_size) / 2
+ right = (width + new_size) / 2
+ bottom = (height + new_size) / 2
+ return image.crop((left, top, right, bottom))
+
+ def train_transforms(img):
+ img = center_crop(img)
+ img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)
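+        # Scale pixel values from [0, 255] to [-1, 1] and convert HWC -> CHW.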
+ img = np.array(img).astype(np.float32) / 127.5 - 1
+ img = torch.from_numpy(np.transpose(img, [2, 0, 1]))
+ return img
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
+ return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values}
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+ # Prepare everything with our `accelerator`.
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ images = batch["pixel_values"].to(weight_dtype)
+ clip_images = batch["clip_pixel_values"].to(weight_dtype)
+ latents = vae.encode(images).latents
+ image_embeds = image_encoder(clip_images).image_embeds
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
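+                # The loss target is the sampled noise (epsilon-style objective).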
+ target = noise
+
+ # Predict the noise residual and compute loss
+ added_cond_kwargs = {"image_embeds": image_embeds}
+
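+                # The Kandinsky 2.2 decoder UNet predicts both noise and variance channels;
+                # only the first 4 (noise) channels enter the loss.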
+ model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = lora_layers.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
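+                # AutoPipelineForText2Image resolves to the combined Kandinsky pipeline (prior + decoder);
+                # only the decoder UNet is swapped for the LoRA-tuned copy here.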
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+ unet.save_attn_procs(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_decoder_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.unet.load_attn_procs(args.output_dir)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+ if accelerator.is_main_process:
+ for tracker in accelerator.trackers:
+ if len(images) != 0:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ed3377db131a5c35219a644488a8a2900813e36
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -0,0 +1,844 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+
+import diffusers
+from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = "", dataset_name: str = "", repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- kandinsky
+- text-to-image
+- diffusers
+- diffusers-training
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of finetuning the Kandinsky 2.2 prior with LoRA.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-decoder",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+            " `args.validation_prompt` `args.num_validation_images` times."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="kandi_2_2-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+            "The resolution for input images; all images in the train/validation dataset will be resized to this"
+            " resolution."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+        help="Weight decay to use.",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ # Load scheduler, image_processor, tokenizer and models.
+ noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+ )
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder"
+ )
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder"
+ )
+ prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+ # freeze parameters of models to save more memory
+ image_encoder.requires_grad_(False)
+ prior.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move image_encoder, text_encoder and prior to device and cast to weight_dtype
+ prior.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
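+    # Every attention layer of the prior transformer receives a LoRA processor with the same
+    # hidden size (2048 for the Kandinsky 2.2 prior).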
+ lora_attn_procs = {}
+ for name in prior.attn_processors.keys():
+ lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048, rank=args.rank)
+
+ prior.set_attn_processor(lora_attn_procs)
+ lora_layers = AttnProcsLayers(prior.attn_processors)
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+    # Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+            f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+            f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ text_input_ids = inputs.input_ids
+ text_mask = inputs.attention_mask.bool()
+ return text_input_ids, text_mask
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
+ text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
+ text_mask = torch.stack([example["text_mask"] for example in examples])
+ return {"clip_pixel_values": clip_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
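+    # Cache the prior's CLIP embedding statistics; they are used below to normalize the
+    # image embeddings before noise is added.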
+ clip_mean = prior.clip_mean.clone()
+ clip_std = prior.clip_std.clone()
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
+ clip_std = clip_std.to(weight_dtype).to(accelerator.device)
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ prior.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(prior):
+ # Convert images to latent space
+ text_input_ids, text_mask, clip_images = (
+ batch["text_input_ids"],
+ batch["text_mask"],
+ batch["clip_pixel_values"].to(weight_dtype),
+ )
+ with torch.no_grad():
+ text_encoder_output = text_encoder(text_input_ids)
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ image_embeds = image_encoder(clip_images).image_embeds
+ # Sample noise that we'll add to the image_embeds
+ noise = torch.randn_like(image_embeds)
+ bsz = image_embeds.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device
+ )
+ timesteps = timesteps.long()
+ image_embeds = (image_embeds - clip_mean) / clip_std
+ noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
+
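+                # With prediction_type="sample" the prior regresses the clean (normalized)
+                # image embedding directly, so the target is image_embeds rather than the noise.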
+ target = image_embeds
+
+ # Predict the noise residual and compute loss
+ model_pred = prior(
+ noisy_latents,
+ timestep=timesteps,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # The prior is trained with sample prediction (the scheduler above uses
+                    # prediction_type="sample"), so neither rescaling branch below applies and the
+                    # weights reduce to min(SNR, snr_gamma).
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
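+                # In the combined Kandinsky pipeline, prior components are overridden via a
+                # `prior_` prefix, so the LoRA-tuned prior is passed as `prior_prior`.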
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_prior=accelerator.unwrap_model(prior),
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ prior = prior.to(torch.float32)
+ prior.save_attn_procs(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_prior_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.prior_prior.load_attn_procs(args.output_dir)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+ if accelerator.is_main_process:
+ for tracker in accelerator.trackers:
+ if len(images) != 0:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..448429444448820e18c646f08857bf1da7bdcb95
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -0,0 +1,958 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+        img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {args.pretrained_prior_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- kandinsky
+- text-to-image
+- diffusers
+- diffusers-training
+inference: true
+---
+ """
+ model_card = f"""
+# Finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}\n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
+image = pipe_t2i(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(
+ image_encoder, image_processor, text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch
+):
+ logger.info("Running validation... ")
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_image_encoder=accelerator.unwrap_model(image_encoder),
+ prior_image_processor=image_processor,
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ prior_prior=accelerator.unwrap_model(prior),
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-decoder",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="kandi_2_2-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+            "The resolution for input images; all images in the train/validation dataset will be resized to this"
+            " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="learning rate",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+        help="Weight decay to use.",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+            "The `project_name` argument passed to Accelerator.init_trackers. For"
+            " more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, image_processor, tokenizer and models.
+ noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+ )
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+        Returns either a list containing a context manager that disables zero.Init, or an empty list.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
+ ).eval()
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
+ ).eval()
+
+ prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+
+ # Freeze text_encoder and image_encoder
+ text_encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ # Set prior to trainable.
+ prior.train()
+
+ # Create EMA for the prior.
+ if args.use_ema:
+ ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+ ema_prior = EMAModel(ema_prior.parameters(), model_cls=PriorTransformer, model_config=ema_prior.config)
+ ema_prior.to(accelerator.device)
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if args.use_ema:
+ ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "prior"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), PriorTransformer)
+ ema_prior.load_state_dict(load_model.state_dict())
+ ema_prior.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = PriorTransformer.from_pretrained(input_dir, subfolder="prior")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+ optimizer = optimizer_cls(
+ prior.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+            f"--image_column value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+            f"--caption_column value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ text_input_ids = inputs.input_ids
+ text_mask = inputs.attention_mask.bool()
+ return text_input_ids, text_mask
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
+ text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
+ text_mask = torch.stack([example["text_mask"] for example in examples])
+ return {"clip_pixel_values": clip_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
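+    # Keep copies of the CLIP embedding statistics as plain tensors: after `accelerator.prepare`
+    # wraps the prior, these clones are cast to `weight_dtype`, moved to the device, and used to
+    # normalize the image embeddings before noise is added.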
+ clip_mean = prior.clip_mean.clone()
+ clip_std = prior.clip_std.clone()
+
+ prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ prior, optimizer, train_dataloader, lr_scheduler
+ )
+
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
+ clip_std = clip_std.to(weight_dtype).to(accelerator.device)
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(prior):
+ # Convert images to latent space
+ text_input_ids, text_mask, clip_images = (
+ batch["text_input_ids"],
+ batch["text_mask"],
+ batch["clip_pixel_values"].to(weight_dtype),
+ )
+ with torch.no_grad():
+ text_encoder_output = text_encoder(text_input_ids)
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ image_embeds = image_encoder(clip_images).image_embeds
+ # Sample noise that we'll add to the image_embeds
+ noise = torch.randn_like(image_embeds)
+ bsz = image_embeds.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device
+ )
+ timesteps = timesteps.long()
+ image_embeds = (image_embeds - clip_mean) / clip_std
+ noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
+
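+                # The DDPMScheduler above is configured with prediction_type="sample", so the prior
+                # learns to predict the clean (normalized) image embedding directly; hence the target
+                # is image_embeds rather than the sampled noise.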
+ target = image_embeds
+
+                # Predict the image embedding and compute the loss
+ model_pred = prior(
+ noisy_latents,
+ timestep=timesteps,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+                    # Compute loss weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # The scheduler here uses prediction_type="sample", so the weight stays at
+                    # min(SNR, gamma); the epsilon/v-prediction rescalings below only apply if the
+                    # prediction type is changed.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
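+                    # Worked example with the suggested snr_gamma=5.0: a timestep with SNR = 12.3 gets
+                    # weight min(12.3, 5.0) = 5.0, while one with SNR = 0.8 keeps weight 0.8, so
+                    # low-noise (high-SNR) timesteps no longer dominate the loss.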
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_prior.step(prior.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+                    # Store the prior parameters temporarily and load the EMA parameters to perform inference.
+ ema_prior.store(prior.parameters())
+ ema_prior.copy_to(prior.parameters())
+ log_validation(
+ image_encoder,
+ image_processor,
+ text_encoder,
+ tokenizer,
+ prior,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+                    # Switch back to the original prior parameters.
+ ema_prior.restore(prior.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ prior = accelerator.unwrap_model(prior)
+ if args.use_ema:
+ ema_prior.copy_to(prior.parameters())
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_image_encoder=image_encoder,
+ prior_text_encoder=text_encoder,
+ prior_prior=prior,
+ )
+ pipeline.prior_pipe.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/reinforcement_learning/README.md b/diffusers/examples/reinforcement_learning/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3c3ada2031cf76678f7cdea0a53db27532951025
--- /dev/null
+++ b/diffusers/examples/reinforcement_learning/README.md
@@ -0,0 +1,22 @@
+# Overview
+
+These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
+The script `run_diffuser_locomotion.py` can be used in two ways.
+
+The key option is the variable `n_guide_steps`.
+When `n_guide_steps=0`, trajectories are sampled from the diffusion model without value-function guidance, so they are not steered toward higher reward in the environment.
+By default, `n_guide_steps=2`, matching the original implementation.
+
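+For example, here is a minimal sketch of disabling guidance when calling the pipeline from your own driver script (this assumes the experimental pipeline's `__call__` accepts `n_guide_steps`, as in the diffusers version pinned here):
+
+```python
+import d4rl  # noqa: F401  (registers the hopper environments)
+import gym
+
+from diffusers.experimental import ValueGuidedRLPipeline
+
+env = gym.make("hopper-medium-v2")
+pipeline = ValueGuidedRLPipeline.from_pretrained(
+    "bglick13/hopper-medium-v2-value-function-hor32",
+    env=env,
+)
+
+obs = env.reset()
+# n_guide_steps=0 samples a trajectory from the diffusion model without value guidance.
+denorm_actions = pipeline(obs, planning_horizon=32, n_guide_steps=0)
+```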
+
+You will need some RL-specific requirements to run the examples:
+
+```sh
+pip install -f https://download.pytorch.org/whl/torch_stable.html \
+ free-mujoco-py \
+ einops \
+ gym==0.24.1 \
+ protobuf==3.20.1 \
+ git+https://github.com/rail-berkeley/d4rl.git \
+ mediapy \
+ Pillow==9.0.0
+```
diff --git a/diffusers/examples/reinforcement_learning/run_diffuser_locomotion.py b/diffusers/examples/reinforcement_learning/run_diffuser_locomotion.py
new file mode 100644
index 0000000000000000000000000000000000000000..adf6d1443d1c2e7caca7bdc1a26da1f2f186b8f9
--- /dev/null
+++ b/diffusers/examples/reinforcement_learning/run_diffuser_locomotion.py
@@ -0,0 +1,59 @@
+import d4rl # noqa
+import gym
+import tqdm
+from diffusers.experimental import ValueGuidedRLPipeline
+
+
+config = {
+ "n_samples": 64,
+ "horizon": 32,
+ "num_inference_steps": 20,
+ "n_guide_steps": 2, # can set to 0 for faster sampling, does not use value network
+ "scale_grad_by_std": True,
+ "scale": 0.1,
+ "eta": 0.0,
+ "t_grad_cutoff": 2,
+ "device": "cpu",
+}
+
+
+if __name__ == "__main__":
+ env_name = "hopper-medium-v2"
+ env = gym.make(env_name)
+
+ pipeline = ValueGuidedRLPipeline.from_pretrained(
+ "bglick13/hopper-medium-v2-value-function-hor32",
+ env=env,
+ )
+
+ env.seed(0)
+ obs = env.reset()
+ total_reward = 0
+ total_score = 0
+ T = 1000
+ rollout = [obs.copy()]
+ try:
+ for t in tqdm.tqdm(range(T)):
+ # call the policy
+ denorm_actions = pipeline(obs, planning_horizon=32)
+
+ # execute action in environment
+ next_observation, reward, terminal, _ = env.step(denorm_actions)
+ score = env.get_normalized_score(total_reward)
+
+ # update return
+ total_reward += reward
+ total_score += score
+ print(
+ f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
+ f" {total_score}"
+ )
+
+ # save observations for rendering
+ rollout.append(next_observation.copy())
+
+ obs = next_observation
+ except KeyboardInterrupt:
+ pass
+
+ print(f"Total reward: {total_reward}")
diff --git a/diffusers/examples/research_projects/README.md b/diffusers/examples/research_projects/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bd0fec9690e6be01b989861510ddc4d7f7d4a099
--- /dev/null
+++ b/diffusers/examples/research_projects/README.md
@@ -0,0 +1,14 @@
+# Research projects
+
+This folder contains various research projects using 🧨 Diffusers.
+They are not really maintained by the core maintainers of this library and often require a specific version of Diffusers that is indicated in the requirements file of each folder.
+Updating them to the most recent version of the library will require some work.
+
+To use any of them, just run the command
+
+```sh
+pip install -r requirements.txt
+```
+inside the folder of your choice.
+
+If you need help with any of those, please open an issue where you directly ping the author(s), as indicated at the top of the README of each folder.
diff --git a/diffusers/examples/research_projects/colossalai/README.md b/diffusers/examples/research_projects/colossalai/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..be94950b772eaaac994104f0787d9ffbfc769f63
--- /dev/null
+++ b/diffusers/examples/research_projects/colossalai/README.md
@@ -0,0 +1,111 @@
+# [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) by [colossalai](https://github.com/hpcaitech/ColossalAI.git)
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
+The `train_dreambooth_colossalai.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+
+By accommodating model data in CPU and GPU memory and moving it to the computing device only when necessary, [Gemini](https://www.colossalai.org/docs/advanced_tutorials/meet_gemini), the heterogeneous memory manager of [Colossal-AI](https://github.com/hpcaitech/ColossalAI), can break through the GPU memory wall by using GPU and CPU memory (CPU DRAM or NVMe SSD) together. Moreover, the model scale can be increased further by combining heterogeneous training with other parallel approaches, such as data, tensor, and pipeline parallelism.
+
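+For reference, wrapping a model with Gemini looks roughly like the `gemini_zero_dpp` helper used later in `train_dreambooth_colossalai.py` (a minimal sketch; the exact ColossalAI API and argument names are tied to the version pinned by this example and may differ in newer releases):
+
+```python
+import colossalai
+import torch.nn as nn
+from colossalai.nn.parallel import GeminiDDP
+from colossalai.utils import get_current_device
+
+colossalai.launch_from_torch(config={})  # run with: torchrun --nproc_per_node 1 this_file.py
+model = nn.Linear(1024, 1024)  # stand-in for the UNet loaded in the training script
+
+# Parameters are placed across CPU/GPU memory according to `placement_policy`
+# ("cpu", "cuda", or "auto"), which is what the `--placement` flag controls.
+model = GeminiDDP(
+    model, device=get_current_device(), placement_policy="auto", pin_memory=True, search_range_mb=64
+)
+```
+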
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -r requirement.txt
+```
+
+## Install [ColossalAI](https://github.com/hpcaitech/ColossalAI.git)
+
+**From PyPI**
+```bash
+pip install colossalai
+```
+
+**From source**
+
+```bash
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI
+
+# install colossalai
+pip install .
+```
+
+## Dataset for Teyvat BLIP captions
+Dataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion).
+
+BLIP generated captions for character images from the [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters) and the [biligame wiki for Genshin Impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2).
+
+For each row the dataset contains `image` and `text` keys. `image` is a varying-size PIL PNG, and `text` is the accompanying text caption. Only a train split is provided.
+
+The `text` includes the tags `Teyvat`, `Name`, `Element`, `Weapon`, `Region`, `Model type`, and `Description`; the `Description` is generated with the [pre-trained BLIP model](https://github.com/salesforce/BLIP).
+
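+If you want to inspect the data before training, here is a minimal sketch with 🤗 Datasets (the dataset id below is only an illustrative placeholder, not necessarily the exact dataset used here; substitute your own copy of the Teyvat captions):
+
+```python
+from datasets import load_dataset
+
+# NOTE: "Fazzie/Teyvat" is an assumed/example identifier for the captions dataset.
+dataset = load_dataset("Fazzie/Teyvat", split="train")
+
+sample = dataset[0]
+print(sample["text"])  # caption containing the Teyvat / Name / Element / ... tags
+sample["image"]        # varying-size PIL image
+```
+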
+## Training
+
+The argument `placement` can be `cpu`, `auto`, or `cuda`. With `cpu`, the required GPU RAM can be reduced to about 4GB at the cost of slower training; with `cuda`, GPU memory usage can still be roughly halved while training remains fast; with `auto`, a more balanced trade-off between speed and memory is obtained.
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400 \
+ --placement="cuda"
+```
+
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=800 \
+ --placement="cuda"
+```
+
+## Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-save-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
+```
diff --git a/diffusers/examples/research_projects/colossalai/inference.py b/diffusers/examples/research_projects/colossalai/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b115c2d2b8f5bcdb3a0c053a6c71b91a965c573
--- /dev/null
+++ b/diffusers/examples/research_projects/colossalai/inference.py
@@ -0,0 +1,12 @@
+import torch
+
+from diffusers import StableDiffusionPipeline
+
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
diff --git a/diffusers/examples/research_projects/colossalai/requirement.txt b/diffusers/examples/research_projects/colossalai/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f80467dcff521bfed1fa72109e1e01e92ab05646
--- /dev/null
+++ b/diffusers/examples/research_projects/colossalai/requirement.txt
@@ -0,0 +1,7 @@
+diffusers
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
+transformers
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/diffusers/examples/research_projects/colossalai/train_dreambooth_colossalai.py
new file mode 100644
index 0000000000000000000000000000000000000000..10c8e095a69652eb47e01ed7d1368dbdf9ed735c
--- /dev/null
+++ b/diffusers/examples/research_projects/colossalai/train_dreambooth_colossalai.py
@@ -0,0 +1,671 @@
+import argparse
+import math
+import os
+from pathlib import Path
+
+import colossalai
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel.utils import get_static_torch_model
+from colossalai.utils import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+
+
+disable_existing_loggers()
+logger = get_dist_logger()
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str = None):
+    # Take `revision` as an argument instead of relying on the module-level `args` defined under `__main__`.
+    text_encoder_config = PretrainedConfig.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=revision,
+    )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default="a photo of sks dog",
+ required=False,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+            "The resolution for input images; all images in the train/validation dataset will be resized to this"
+            " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--placement",
+ type=str,
+ default="cpu",
+ help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ if args.class_data_dir is not None:
+ logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ logger.warning("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+            raise ValueError("Instance images root doesn't exist.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+# Gemini + ZeRO DDP
+def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
+    from colossalai.nn.parallel import GeminiDDP
+
+    model = GeminiDDP(
+        model, device=get_current_device(), placement_policy=placement_policy, pin_memory=True, search_range_mb=64
+    )
+ return model
+
+
+def main(args):
+ if args.seed is None:
+ colossalai.launch_from_torch(config={})
+ else:
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+
+ local_rank = gpc.get_local_rank(ParallelMode.DATA)
+ world_size = gpc.get_world_size(ParallelMode.DATA)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+            torch_dtype = torch.float16 if get_current_device().type == "cuda" else torch.float32
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ pipeline.to(get_current_device())
+
+ for example in tqdm(
+ sample_dataloader,
+ desc="Generating class images",
+ disable=not local_rank == 0,
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+
+ # Handle the repository creation
+ if local_rank == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0])
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer_name,
+ revision=args.revision,
+ use_fast=False,
+ )
+ elif args.pretrained_model_name_or_path:
+ logger.info("Loading tokenizer from pretrained model", ranks=[0])
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
+
+ # Load models and create wrapper for stable diffusion
+
+ logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
+
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ )
+
+ logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ )
+
+ logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
+ with ColoInitContext(device=get_current_device()):
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
+ )
+
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * args.train_batch_size * world_size
+
+    unet = gemini_zero_ddp(unet, args.placement)
+
+    # Configure the optimizer for ColossalAI ZeRO
+ optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+
+ # load noise_scheduler
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ # prepare dataset
+ logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = tokenizer.pad(
+ {"input_ids": input_ids},
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+ return batch
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move text_encode and vae to gpu.
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ vae.to(get_current_device(), dtype=weight_dtype)
+ text_encoder.to(get_current_device(), dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ total_batch_size = args.train_batch_size * world_size
+
+ logger.info("***** Running training *****", ranks=[0])
+ logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0])
+ logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
+ logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
+
+ # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(args.max_train_steps), disable=local_rank != 0)
+ progress_bar.set_description("Steps")
+ global_step = 0
+
+ torch.cuda.synchronize()
+ for epoch in range(args.num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ torch.cuda.reset_peak_memory_stats()
+ # Move batch to gpu
+ for key, value in batch.items():
+ batch[key] = value.to(get_current_device(), non_blocking=True)
+
+ # Convert images to latent space
+ optimizer.zero_grad()
+
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
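+            # 0.18215 is the scaling factor of the Stable Diffusion VAE latent space.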
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ optimizer.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ progress_bar.update(1)
+ global_step += 1
+ logs = {
+ "loss": loss.detach().item(),
+ "lr": optimizer.param_groups[0]["lr"],
+ } # lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step % args.save_steps == 0:
+ torch.cuda.synchronize()
+ torch_unet = get_static_torch_model(unet)
+ if local_rank == 0:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=torch_unet,
+ revision=args.revision,
+ )
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ pipeline.save_pretrained(save_path)
+ logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
+ if global_step >= args.max_train_steps:
+ break
+
+ torch.cuda.synchronize()
+ unet = get_static_torch_model(unet)
+
+ if local_rank == 0:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unet,
+ revision=args.revision,
+ )
+
+ pipeline.save_pretrained(args.output_dir)
+ logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/consistency_training/README.md b/diffusers/examples/research_projects/consistency_training/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..67eb83514192cf33fc33bebd01559efcf6744387
--- /dev/null
+++ b/diffusers/examples/research_projects/consistency_training/README.md
@@ -0,0 +1,24 @@
+# Consistency Training
+
+`train_cm_ct_unconditional.py` trains a consistency model (CM) from scratch following the consistency training (CT) algorithm introduced in [Consistency Models](https://arxiv.org/abs/2303.01469) and refined in [Improved Techniques for Training Consistency Models](https://arxiv.org/abs/2310.14189). Both unconditional and class-conditional training are supported.
+
+A usage example is as follows:
+
+```bash
+accelerate launch examples/research_projects/consistency_training/train_cm_ct_unconditional.py \
+ --dataset_name="cifar10" \
+ --dataset_image_column_name="img" \
+ --output_dir="/path/to/output/dir" \
+ --mixed_precision=fp16 \
+ --resolution=32 \
+ --max_train_steps=1000 --max_train_samples=10000 \
+ --dataloader_num_workers=8 \
+ --noise_precond_type="cm" --input_precond_type="cm" \
+ --train_batch_size=4 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --use_8bit_adam \
+ --use_ema \
+ --validation_steps=100 --eval_batch_size=4 \
+ --checkpointing_steps=100 --checkpoints_total_limit=10 \
+  --class_conditional --num_classes=10
+```
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/consistency_training/requirements.txt b/diffusers/examples/research_projects/consistency_training/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e19b0ce60bf407bec21a9b85f9232cad957bfa6f
--- /dev/null
+++ b/diffusers/examples/research_projects/consistency_training/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/consistency_training/train_cm_ct_unconditional.py b/diffusers/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
new file mode 100644
index 0000000000000000000000000000000000000000..eccc539f230c672ce309751cb277ecb315c22639
--- /dev/null
+++ b/diffusers/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
@@ -0,0 +1,1438 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Script to train a consistency model from scratch via (improved) consistency training."""
+
+import argparse
+import gc
+import logging
+import math
+import os
+import shutil
+from datetime import timedelta
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+from accelerate import Accelerator, InitProcessGroupKwargs
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import (
+ CMStochasticIterativeScheduler,
+ ConsistencyModelPipeline,
+ UNet2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, resolve_interpolation_mode
+from diffusers.utils import is_tensorboard_available, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ if not isinstance(arr, torch.Tensor):
+ arr = torch.from_numpy(arr)
+ res = arr[timesteps].float().to(timesteps.device)
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def get_discretization_steps(global_step: int, max_train_steps: int, s_0: int = 10, s_1: int = 1280, constant=False):
+ """
+ Calculates the current discretization steps at global step k using the discretization curriculum N(k).
+ """
+ if constant:
+ return s_0 + 1
+
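+    # Discretization curriculum from the iCT paper: N(k) = min(s_0 * 2^floor(k / k'), s_1) + 1,
+    # where k' = floor(K / (log2(floor(s_1 / s_0)) + 1)) and K is the total number of training steps.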
+ k_prime = math.floor(max_train_steps / (math.log2(math.floor(s_1 / s_0)) + 1))
+ num_discretization_steps = min(s_0 * 2 ** math.floor(global_step / k_prime), s_1) + 1
+
+ return num_discretization_steps
+
+
+def get_skip_steps(global_step, initial_skip: int = 1):
+ # Currently only support constant skip curriculum.
+ return initial_skip
+
+
+def get_karras_sigmas(
+ num_discretization_steps: int,
+ sigma_min: float = 0.002,
+ sigma_max: float = 80.0,
+ rho: float = 7.0,
+ dtype=torch.float32,
+):
+ """
+ Calculates the Karras sigmas timestep discretization of [sigma_min, sigma_max].
+ """
+ ramp = np.linspace(0, 1, num_discretization_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ # Make sure sigmas are in increasing rather than decreasing order (see section 2 of the iCT paper)
+ sigmas = sigmas[::-1].copy()
+ sigmas = torch.from_numpy(sigmas).to(dtype=dtype)
+ return sigmas
+
+
+def get_discretized_lognormal_weights(noise_levels: torch.Tensor, p_mean: float = -1.1, p_std: float = 2.0):
+ """
+    Calculates the unnormalized weights for a 1D array of noise levels sigma_i based on the discretized lognormal
+    distribution used in the iCT paper (given in Equation 10).
+ """
+ upper_prob = torch.special.erf((torch.log(noise_levels[1:]) - p_mean) / (math.sqrt(2) * p_std))
+ lower_prob = torch.special.erf((torch.log(noise_levels[:-1]) - p_mean) / (math.sqrt(2) * p_std))
+ weights = upper_prob - lower_prob
+ return weights
+
+
+def get_loss_weighting_schedule(noise_levels: torch.Tensor):
+ """
+ Calculates the loss weighting schedule lambda given a set of noise levels.
+ """
+ return 1.0 / (noise_levels[1:] - noise_levels[:-1])
+
+
+def add_noise(original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
+ # Make sure timesteps (Karras sigmas) have the same device and dtype as original_samples
+ sigmas = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)
+ while len(sigmas.shape) < len(original_samples.shape):
+ sigmas = sigmas.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigmas
+
+ return noisy_samples
+
+
+def get_noise_preconditioning(sigmas, noise_precond_type: str = "cm"):
+ """
+ Calculates the noise preconditioning function c_noise, which is used to transform the raw Karras sigmas into the
+ timestep input for the U-Net.
+ """
+ if noise_precond_type == "none":
+ return sigmas
+ elif noise_precond_type == "edm":
+ return 0.25 * torch.log(sigmas)
+ elif noise_precond_type == "cm":
+ return 1000 * 0.25 * torch.log(sigmas + 1e-44)
+ else:
+ raise ValueError(
+ f"Noise preconditioning type {noise_precond_type} is not current supported. Currently supported noise"
+ f" preconditioning types are `none` (which uses the sigmas as is), `edm`, and `cm`."
+ )
+
+
+def get_input_preconditioning(sigmas, sigma_data=0.5, input_precond_type: str = "cm"):
+ """
+ Calculates the input preconditioning factor c_in, which is used to scale the U-Net image input.
+ """
+ if input_precond_type == "none":
+ return 1
+ elif input_precond_type == "cm":
+ return 1.0 / (sigmas**2 + sigma_data**2)
+ else:
+ raise ValueError(
+ f"Input preconditioning type {input_precond_type} is not current supported. Currently supported input"
+ f" preconditioning types are `none` (which uses a scaling factor of 1.0) and `cm`."
+ )
+
+
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=1.0):
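+    # Boundary-condition scalings: as the (scaled) timestep goes to 0, c_skip -> 1 and c_out -> 0, so the
+    # consistency model output approaches its input at the minimum noise level.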
+ scaled_timestep = timestep_scaling * timestep
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name="teacher"):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = ConsistencyModelPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ )
+ pipeline = pipeline.to(device=accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ class_labels = [None]
+ if args.class_conditional:
+ if args.num_classes is not None:
+ class_labels = list(range(args.num_classes))
+ else:
+ logger.warning(
+ "The model is class-conditional but the number of classes is not set. The generated images will be"
+ " unconditional rather than class-conditional."
+ )
+
+ image_logs = []
+
+ for class_label in class_labels:
+ images = []
+ with torch.autocast("cuda"):
+ images = pipeline(
+ num_inference_steps=1,
+ batch_size=args.eval_batch_size,
+ class_labels=[class_label] * args.eval_batch_size,
+ generator=generator,
+ ).images
+ log = {"images": images}
+ if args.class_conditional and class_label is not None:
+ log["class_label"] = str(class_label)
+ else:
+ log["class_label"] = "images"
+ image_logs.append(log)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ class_label = log["class_label"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(class_label, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ class_label = log["class_label"]
+ for image in images:
+ image = wandb.Image(image, caption=class_label)
+ formatted_images.append(image)
+
+ tracker.log({f"validation/{name}": formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ------------Model Arguments-----------
+ parser.add_argument(
+ "--model_config_name_or_path",
+ type=str,
+ default=None,
+ help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ help=(
+ "If initializing the weights from a pretrained model, the path to the pretrained model or model identifier"
+ " from huggingface.co/models."
+ ),
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help=(
+ "Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. `fp16`,"
+ " `non_ema`, etc.",
+ ),
+ )
+ # ------------Dataset Arguments-----------
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that HF Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--dataset_image_column_name",
+ type=str,
+ default="image",
+ help="The name of the image column in the dataset to use for training.",
+ )
+ parser.add_argument(
+ "--dataset_class_label_column_name",
+ type=str,
+ default="label",
+ help="If doing class-conditional training, the name of the class label column in the dataset to use.",
+ )
+ # ------------Image Processing Arguments-----------
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=64,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--interpolation_type",
+ type=str,
+ default="bilinear",
+ help=(
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ default=False,
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--class_conditional",
+ action="store_true",
+ help=(
+ "Whether to train a class-conditional model. If set, the class labels will be taken from the `label`"
+ " column of the provided dataset."
+ ),
+ )
+ parser.add_argument(
+ "--num_classes",
+ type=int,
+ default=None,
+ help="The number of classes in the training data, if training a class-conditional model.",
+ )
+ parser.add_argument(
+ "--class_embed_type",
+ type=str,
+ default=None,
+ help=(
+ "The class embedding type to use. Choose from `None`, `identity`, and `timestep`. If `class_conditional`"
+ " and `num_classes` and set, but `class_embed_type` is `None`, a embedding matrix will be used."
+ ),
+ )
+ # ------------Dataloader Arguments-----------
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+ " process."
+ ),
+ )
+ # ------------Training Arguments-----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="ddpm-model-64",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--overwrite_output_dir", action="store_true")
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Batch Size and Training Length----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="cosine",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ # ----Optimizer (Adam) Arguments----
+ parser.add_argument(
+ "--optimizer_type",
+ type=str,
+ default="adamw",
+ help=(
+ "The optimizer algorithm to use for training. Choose between `radam` and `adamw`. The iCT paper uses"
+ " RAdam."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Consistency Training (CT) Specific Arguments----
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default="sample",
+ choices=["sample"],
+ help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+ )
+ parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+ parser.add_argument(
+ "--sigma_min",
+ type=float,
+ default=0.002,
+ help=(
+ "The lower boundary for the timestep discretization, which should be set to a small positive value close"
+ " to zero to avoid numerical issues when solving the PF-ODE backwards in time."
+ ),
+ )
+ parser.add_argument(
+ "--sigma_max",
+ type=float,
+ default=80.0,
+ help=(
+ "The upper boundary for the timestep discretization, which also determines the variance of the Gaussian"
+ " prior."
+ ),
+ )
+ parser.add_argument(
+ "--rho",
+ type=float,
+ default=7.0,
+ help="The rho parameter for the Karras sigmas timestep dicretization.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=None,
+ help=(
+ "The Pseudo-Huber loss parameter c. If not set, this will default to the value recommended in the Improved"
+ " Consistency Training (iCT) paper of 0.00054 * sqrt(d), where d is the data dimensionality."
+ ),
+ )
+ parser.add_argument(
+ "--discretization_s_0",
+ type=int,
+ default=10,
+ help=(
+ "The s_0 parameter in the discretization curriculum N(k). This controls the number of training steps after"
+ " which the number of discretization steps N will be doubled."
+ ),
+ )
+ parser.add_argument(
+ "--discretization_s_1",
+ type=int,
+ default=1280,
+ help=(
+ "The s_1 parameter in the discretization curriculum N(k). This controls the upper limit to the number of"
+ " discretization steps used. Increasing this value will reduce the bias at the cost of higher variance."
+ ),
+ )
+ parser.add_argument(
+ "--constant_discretization_steps",
+ action="store_true",
+ help=(
+ "Whether to set the discretization curriculum N(k) to be the constant value `discretization_s_0 + 1`. This"
+ " is useful for testing when `max_number_steps` is small, when `k_prime` would otherwise be 0, causing"
+ " a divide-by-zero error."
+ ),
+ )
+ parser.add_argument(
+ "--p_mean",
+ type=float,
+ default=-1.1,
+ help=(
+ "The mean parameter P_mean for the (discretized) lognormal noise schedule, which controls the probability"
+ " of sampling a (discrete) noise level sigma_i."
+ ),
+ )
+ parser.add_argument(
+ "--p_std",
+ type=float,
+ default=2.0,
+ help=(
+ "The standard deviation parameter P_std for the (discretized) noise schedule, which controls the"
+ " probability of sampling a (discrete) noise level sigma_i."
+ ),
+ )
+ parser.add_argument(
+ "--noise_precond_type",
+ type=str,
+ default="cm",
+ help=(
+ "The noise preconditioning function to use for transforming the raw Karras sigmas into the timestep"
+ " argument of the U-Net. Choose between `none` (the identity function), `edm`, and `cm`."
+ ),
+ )
+ parser.add_argument(
+ "--input_precond_type",
+ type=str,
+ default="cm",
+ help=(
+ "The input preconditioning function to use for scaling the image input of the U-Net. Choose between `none`"
+ " (a scaling factor of 1) and `cm`."
+ ),
+ )
+ parser.add_argument(
+ "--skip_steps",
+ type=int,
+ default=1,
+ help=(
+ "The gap in indices between the student and teacher noise levels. In the iCT paper this is always set to"
+ " 1, but theoretically this could be greater than 1 and/or altered according to a curriculum throughout"
+ " training, much like the number of discretization steps is."
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher",
+ action="store_true",
+ help="Whether to cast the teacher U-Net model to `weight_dtype` or leave it in full precision.",
+ )
+ # ----Exponential Moving Average (EMA) Arguments----
+ parser.add_argument(
+ "--use_ema",
+ action="store_true",
+ help="Whether to use Exponential Moving Average for the final model weights.",
+ )
+ parser.add_argument(
+ "--ema_min_decay",
+ type=float,
+ default=None,
+ help=(
+ "The minimum decay magnitude for EMA. If not set, this will default to the value of `ema_max_decay`,"
+ " resulting in a constant EMA decay rate."
+ ),
+ )
+ parser.add_argument(
+ "--ema_max_decay",
+ type=float,
+ default=0.99993,
+ help=(
+ "The maximum decay magnitude for EMA. Setting `ema_min_decay` equal to this value will result in a"
+ " constant decay rate."
+ ),
+ )
+ parser.add_argument(
+ "--use_ema_warmup",
+ action="store_true",
+ help="Whether to use EMA warmup.",
+ )
+ parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+ parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+ # ----Training Optimization Arguments----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ # ----Distributed Training Arguments----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ------------Validation Arguments-----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ parser.add_argument(
+ "--eval_batch_size",
+ type=int,
+ default=16,
+ help=(
+ "The number of images to generate for evaluation. Note that if `class_conditional` and `num_classes` is"
+ " set the effective number of images generated per evaluation step is `eval_batch_size * num_classes`."
+ ),
+ )
+ parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+    # ------------Checkpointing Arguments-----------
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+ )
+ # ------------Logging Arguments-----------
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ # ------------HuggingFace Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+ )
+ # ------------Accelerate Arguments-----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="consistency-training",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+ return args
+
+
+def main(args):
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200)) # a big number for high resolution or big dataset
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "tensorboard":
+ if not is_tensorboard_available():
+ raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
+
+ elif args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # 1. Initialize the noise scheduler.
+ initial_discretization_steps = get_discretization_steps(
+ 0,
+ args.max_train_steps,
+ s_0=args.discretization_s_0,
+ s_1=args.discretization_s_1,
+ constant=args.constant_discretization_steps,
+ )
+ noise_scheduler = CMStochasticIterativeScheduler(
+ num_train_timesteps=initial_discretization_steps,
+ sigma_min=args.sigma_min,
+ sigma_max=args.sigma_max,
+ rho=args.rho,
+ )
+
+ # 2. Initialize the student U-Net model.
+ if args.pretrained_model_name_or_path is not None:
+ logger.info(f"Loading pretrained U-Net weights from {args.pretrained_model_name_or_path}... ")
+ unet = UNet2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ elif args.model_config_name_or_path is None:
+ # TODO: use default architectures from iCT paper
+ if not args.class_conditional and (args.num_classes is not None or args.class_embed_type is not None):
+ logger.warning(
+ f"`--class_conditional` is set to `False` but `--num_classes` is set to {args.num_classes} and"
+ f" `--class_embed_type` is set to {args.class_embed_type}. These values will be overridden to `None`."
+ )
+ args.num_classes = None
+ args.class_embed_type = None
+ elif args.class_conditional and args.num_classes is None and args.class_embed_type is None:
+ logger.warning(
+ "`--class_conditional` is set to `True` but neither `--num_classes` nor `--class_embed_type` is set."
+ "`class_conditional` will be overridden to `False`."
+ )
+ args.class_conditional = False
+ unet = UNet2DModel(
+ sample_size=args.resolution,
+ in_channels=3,
+ out_channels=3,
+ layers_per_block=2,
+ block_out_channels=(128, 128, 256, 256, 512, 512),
+ down_block_types=(
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "AttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types=(
+ "UpBlock2D",
+ "AttnUpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ ),
+ class_embed_type=args.class_embed_type,
+ num_class_embeds=args.num_classes,
+ )
+ else:
+ config = UNet2DModel.load_config(args.model_config_name_or_path)
+ unet = UNet2DModel.from_config(config)
+ unet.train()
+
+ # Create EMA for the student U-Net model.
+ if args.use_ema:
+ if args.ema_min_decay is None:
+ args.ema_min_decay = args.ema_max_decay
+ ema_unet = EMAModel(
+ unet.parameters(),
+ decay=args.ema_max_decay,
+ min_decay=args.ema_min_decay,
+ use_ema_warmup=args.use_ema_warmup,
+ inv_gamma=args.ema_inv_gamma,
+ power=args.ema_power,
+ model_cls=UNet2DModel,
+ model_config=unet.config,
+ )
+
+ # 3. Initialize the teacher U-Net model from the student U-Net model.
+    # Note that following the improved Consistency Training paper, the teacher U-Net is not updated via EMA (i.e. the
+    # EMA decay rate is 0).
+ teacher_unet = UNet2DModel.from_config(unet.config)
+ teacher_unet.load_state_dict(unet.state_dict())
+ teacher_unet.train()
+ teacher_unet.requires_grad_(False)
+
+ # 4. Handle mixed precision and device placement
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ args.mixed_precision = accelerator.mixed_precision
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ args.mixed_precision = accelerator.mixed_precision
+
+ # Cast teacher_unet to weight_dtype if cast_teacher is set.
+ if args.cast_teacher:
+ teacher_dtype = weight_dtype
+ else:
+ teacher_dtype = torch.float32
+
+ teacher_unet.to(accelerator.device)
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # 5. Handle saving and loading of checkpoints.
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ teacher_unet.save_pretrained(os.path.join(output_dir, "unet_teacher"))
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ load_model = UNet2DModel.from_pretrained(os.path.join(input_dir, "unet_teacher"))
+ teacher_unet.load_state_dict(load_model.state_dict())
+ teacher_unet.to(accelerator.device)
+ del load_model
+
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 6. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ if args.use_ema:
+ ema_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ if args.optimizer_type == "radam":
+ optimizer_class = torch.optim.RAdam
+ elif args.optimizer_type == "adamw":
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model for 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+ else:
+ raise ValueError(
+ f"Optimizer type {args.optimizer_type} is not supported. Currently supported optimizer types are `radam`"
+ f" and `adamw`."
+ )
+
+ # 7. Initialize the optimizer
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 8. Dataset creation and data preprocessing
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ split="train",
+ )
+ else:
+ dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets and DataLoaders creation.
+ interpolation_mode = resolve_interpolation_mode(args.interpolation_type)
+ augmentations = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=interpolation_mode),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def transform_images(examples):
+ images = [augmentations(image.convert("RGB")) for image in examples[args.dataset_image_column_name]]
+ batch_dict = {"images": images}
+ if args.class_conditional:
+ batch_dict["class_labels"] = examples[args.dataset_class_label_column_name]
+ return batch_dict
+
+ logger.info(f"Dataset size: {len(dataset)}")
+
+ dataset.set_transform(transform_images)
+ train_dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # 9. Initialize the learning rate scheduler
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # 10. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ def recalculate_num_discretization_step_values(discretization_steps, skip_steps):
+ """
+ Recalculates all quantities depending on the number of discretization steps N.
+ """
+ noise_scheduler = CMStochasticIterativeScheduler(
+ num_train_timesteps=discretization_steps,
+ sigma_min=args.sigma_min,
+ sigma_max=args.sigma_max,
+ rho=args.rho,
+ )
+ current_timesteps = get_karras_sigmas(discretization_steps, args.sigma_min, args.sigma_max, args.rho)
+ valid_teacher_timesteps_plus_one = current_timesteps[: len(current_timesteps) - skip_steps + 1]
+ # timestep_weights are the unnormalized probabilities of sampling the timestep/noise level at each index
+ timestep_weights = get_discretized_lognormal_weights(
+ valid_teacher_timesteps_plus_one, p_mean=args.p_mean, p_std=args.p_std
+ )
+ # timestep_loss_weights is the timestep-dependent loss weighting schedule lambda(sigma_i)
+ timestep_loss_weights = get_loss_weighting_schedule(valid_teacher_timesteps_plus_one)
+
+ current_timesteps = current_timesteps.to(accelerator.device)
+ timestep_weights = timestep_weights.to(accelerator.device)
+ timestep_loss_weights = timestep_loss_weights.to(accelerator.device)
+
+ return noise_scheduler, current_timesteps, timestep_weights, timestep_loss_weights
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+    # Function for unwrapping the model if torch.compile() was used in accelerate.
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ # Resolve the c parameter for the Pseudo-Huber loss
+ if args.huber_c is None:
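+        # iCT recommendation c = 0.00054 * sqrt(d), with data dimensionality d = resolution^2 * in_channels,
+        # written here as 0.00054 * resolution * sqrt(in_channels).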
+ args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)
+
+ # Get current number of discretization steps N according to our discretization curriculum
+ current_discretization_steps = get_discretization_steps(
+ initial_global_step,
+ args.max_train_steps,
+ s_0=args.discretization_s_0,
+ s_1=args.discretization_s_1,
+ constant=args.constant_discretization_steps,
+ )
+ current_skip_steps = get_skip_steps(initial_global_step, initial_skip=args.skip_steps)
+ if current_skip_steps >= current_discretization_steps:
+ raise ValueError(
+ f"The current skip steps is {current_skip_steps}, but should be smaller than the current number of"
+ f" discretization steps {current_discretization_steps}"
+ )
+ # Recalculate all quantities depending on the number of discretization steps N
+ (
+ noise_scheduler,
+ current_timesteps,
+ timestep_weights,
+ timestep_loss_weights,
+ ) = recalculate_num_discretization_step_values(current_discretization_steps, current_skip_steps)
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ # 11. Train!
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ # 1. Get batch of images from dataloader (sample x ~ p_data(x))
+ clean_images = batch["images"].to(weight_dtype)
+ if args.class_conditional:
+ class_labels = batch["class_labels"]
+ else:
+ class_labels = None
+ bsz = clean_images.shape[0]
+
+ # 2. Sample a random timestep for each image according to the noise schedule.
+            # Sample random indices i ~ p(i), where p(i) is the discretized lognormal distribution in the iCT paper
+ # NOTE: timestep_indices should be in the range [0, len(current_timesteps) - k - 1] inclusive
+ timestep_indices = torch.multinomial(timestep_weights, bsz, replacement=True).long()
+ teacher_timesteps = current_timesteps[timestep_indices]
+ student_timesteps = current_timesteps[timestep_indices + current_skip_steps]
+
+ # 3. Sample noise and add it to the clean images for both teacher and student unets
+ # Sample noise z ~ N(0, I) that we'll add to the images
+ noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)
+ # Add noise to the clean images according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ teacher_noisy_images = add_noise(clean_images, noise, teacher_timesteps)
+ student_noisy_images = add_noise(clean_images, noise, student_timesteps)
+
+ # 4. Calculate preconditioning and scalings for boundary conditions for the consistency model.
+ teacher_rescaled_timesteps = get_noise_preconditioning(teacher_timesteps, args.noise_precond_type)
+ student_rescaled_timesteps = get_noise_preconditioning(student_timesteps, args.noise_precond_type)
+
+ c_in_teacher = get_input_preconditioning(teacher_timesteps, input_precond_type=args.input_precond_type)
+ c_in_student = get_input_preconditioning(student_timesteps, input_precond_type=args.input_precond_type)
+
+ c_skip_teacher, c_out_teacher = scalings_for_boundary_conditions(teacher_timesteps)
+ c_skip_student, c_out_student = scalings_for_boundary_conditions(student_timesteps)
+
+ c_skip_teacher, c_out_teacher, c_in_teacher = [
+ append_dims(x, clean_images.ndim) for x in [c_skip_teacher, c_out_teacher, c_in_teacher]
+ ]
+ c_skip_student, c_out_student, c_in_student = [
+ append_dims(x, clean_images.ndim) for x in [c_skip_student, c_out_student, c_in_student]
+ ]
+
+ with accelerator.accumulate(unet):
+ # 5. Get the student unet denoising prediction on the student timesteps
+ # Get rng state now to ensure that dropout is synced between the student and teacher models.
+ dropout_state = torch.get_rng_state()
+ student_model_output = unet(
+ c_in_student * student_noisy_images, student_rescaled_timesteps, class_labels=class_labels
+ ).sample
+ # NOTE: currently only support prediction_type == sample, so no need to convert model_output
+ student_denoise_output = c_skip_student * student_noisy_images + c_out_student * student_model_output
+
+ # 6. Get the teacher unet denoising prediction on the teacher timesteps
+ with torch.no_grad(), torch.autocast("cuda", dtype=teacher_dtype):
+ torch.set_rng_state(dropout_state)
+ teacher_model_output = teacher_unet(
+ c_in_teacher * teacher_noisy_images, teacher_rescaled_timesteps, class_labels=class_labels
+ ).sample
+ # NOTE: currently only support prediction_type == sample, so no need to convert model_output
+ teacher_denoise_output = (
+ c_skip_teacher * teacher_noisy_images + c_out_teacher * teacher_model_output
+ )
+
+ # 7. Calculate the weighted Pseudo-Huber loss
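+                # The Pseudo-Huber loss sqrt((x - y)^2 + c^2) - c behaves like a squared error for small residuals and
+                # like an L1 penalty for large ones; lambda_t applies the per-timestep loss weighting from the iCT paper.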
+ if args.prediction_type == "sample":
+ # Note that the loss weights should be those at the (teacher) timestep indices.
+ lambda_t = _extract_into_tensor(
+ timestep_loss_weights, timestep_indices, (bsz,) + (1,) * (clean_images.ndim - 1)
+ )
+ loss = lambda_t * (
+ torch.sqrt(
+ (student_denoise_output.float() - teacher_denoise_output.float()) ** 2 + args.huber_c**2
+ )
+ - args.huber_c
+ )
+ loss = loss.mean()
+ else:
+ raise ValueError(
+ f"Unsupported prediction type: {args.prediction_type}. Currently, only `sample` is supported."
+ )
+
+ # 8. Backpropagate on the consistency training loss
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ # 9. Update teacher_unet and ema_unet parameters using unet's parameters.
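+                # In iCT the teacher is simply the student from the previous optimizer step (a stop-gradient copy, no EMA).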
+ teacher_unet.load_state_dict(unet.state_dict())
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ # 10. Recalculate quantities depending on the global step, if necessary.
+ new_discretization_steps = get_discretization_steps(
+ global_step,
+ args.max_train_steps,
+ s_0=args.discretization_s_0,
+ s_1=args.discretization_s_1,
+ constant=args.constant_discretization_steps,
+ )
+ current_skip_steps = get_skip_steps(global_step, initial_skip=args.skip_steps)
+ if current_skip_steps >= new_discretization_steps:
+ raise ValueError(
+                            f"The current number of skip steps is {current_skip_steps}, but it should be smaller than"
+                            f" the current number of discretization steps {new_discretization_steps}."
+ )
+ if new_discretization_steps != current_discretization_steps:
+ (
+ noise_scheduler,
+ current_timesteps,
+ timestep_weights,
+ timestep_loss_weights,
+ ) = recalculate_num_discretization_step_values(new_discretization_steps, current_skip_steps)
+ current_discretization_steps = new_discretization_steps
+
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ # NOTE: since we do not use EMA for the teacher model, the teacher parameters and student
+ # parameters are the same at this point in time
+ log_validation(unet, noise_scheduler, args, accelerator, weight_dtype, global_step, "teacher")
+ # teacher_unet.to(dtype=teacher_dtype)
+
+ if args.use_ema:
+ # Store the student unet weights and load the EMA weights.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ log_validation(
+ unet,
+ noise_scheduler,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ "ema_student",
+ )
+
+ # Restore student unet weights
+ ema_unet.restore(unet.parameters())
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+ if args.use_ema:
+ logs["ema_decay"] = ema_unet.cur_decay_value
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+ # progress_bar.close()
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ pipeline = ConsistencyModelPipeline(unet=unet, scheduler=noise_scheduler)
+ pipeline.save_pretrained(args.output_dir)
+
+ # If using EMA, save EMA weights as well.
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ unet.save_pretrained(os.path.join(args.output_dir, "ema_unet"))
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/controlnet/train_controlnet_webdataset.py b/diffusers/examples/research_projects/controlnet/train_controlnet_webdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a5d93d8edfc3044d3f325d52b6352d8d7b2dbf
--- /dev/null
+++ b/diffusers/examples/research_projects/controlnet/train_controlnet_webdataset.py
@@ -0,0 +1,1472 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Optional, Union
+
+import accelerate
+import cv2
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, DPTForDepthEstimation, DPTImageProcessor, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ ControlNetModel,
+ EulerDiscreteScheduler,
+ StableDiffusionXLControlNetPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+    """Group key/value pairs from an iterator into sample dicts, yielding each complete sample.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+    # begins, rare, but can happen since prefixes aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+def control_transform(image):
+ image = np.array(image)
+
+ low_threshold = 100
+ high_threshold = 200
+
+ image = cv2.Canny(image, low_threshold, high_threshold)
+ image = image[:, :, None]
+ image = np.concatenate([image, image, image], axis=2)
+ control_image = Image.fromarray(image)
+ return control_image
+
+
+def canny_image_transform(example, resolution=1024):
+ image = example["image"]
+ image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
+ # get crop coordinates
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = transforms.functional.crop(image, c_top, c_left, resolution, resolution)
+ control_image = control_transform(image)
+
+ image = transforms.ToTensor()(image)
+ image = transforms.Normalize([0.5], [0.5])(image)
+ control_image = transforms.ToTensor()(control_image)
+
+ example["image"] = image
+ example["control_image"] = control_image
+ example["crop_coords"] = (c_top, c_left)
+
+ return example
+
+
+def depth_image_transform(example, feature_extractor, resolution=1024):
+ image = example["image"]
+ image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
+ # get crop coordinates
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = transforms.functional.crop(image, c_top, c_left, resolution, resolution)
+
+ control_image = feature_extractor(images=image, return_tensors="pt").pixel_values.squeeze(0)
+
+ image = transforms.ToTensor()(image)
+ image = transforms.Normalize([0.5], [0.5])(image)
+
+ example["image"] = image
+ example["control_image"] = control_image
+ example["crop_coords"] = (c_top, c_left)
+
+ return example
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ eval_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 256,
+ center_crop: bool = True,
+ random_flip: bool = False,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ control_type: str = "canny",
+ feature_extractor: Optional[DPTImageProcessor] = None,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ if not isinstance(eval_shards_path_or_url, str):
+ eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]
+ # flatten list using itertools
+ eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))
+
+ def get_orig_size(json):
+ return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+
+        if control_type == "canny":
+            image_transform = functools.partial(canny_image_transform, resolution=resolution)
+        elif control_type == "depth":
+            image_transform = functools.partial(
+                depth_image_transform, feature_extractor=feature_extractor, resolution=resolution
+            )
+        else:
+            raise ValueError(f"control_type {control_type} is not supported. Choose between `canny` and `depth`.")
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp",
+ control_image="jpg;png;jpeg;webp",
+ text="text;txt;caption",
+ orig_size="json",
+ handler=wds.warn_and_continue,
+ ),
+ wds.map(filter_keys({"image", "control_image", "text", "orig_size"})),
+ wds.map_dict(orig_size=get_orig_size),
+ wds.map(image_transform),
+ wds.to_tuple("image", "control_image", "text", "orig_size", "crop_coords"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.select(WebdatasetFilter(min_size=512)),
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ # Create eval dataset and loader
+ pipeline = [
+ wds.SimpleShardList(eval_shards_path_or_url),
+ wds.split_by_worker,
+ wds.tarfile_to_samples(handler=wds.ignore_and_continue),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+ self._eval_dataset = wds.DataPipeline(*pipeline)
+ self._eval_dataloader = wds.WebLoader(
+ self._eval_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+ @property
+ def eval_dataset(self):
+ return self._eval_dataset
+
+ @property
+ def eval_dataloader(self):
+ return self._eval_dataloader
+
+
+def image_grid(imgs, rows, cols):
+ assert len(imgs) == rows * cols
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ controlnet = accelerator.unwrap_model(controlnet)
+
+ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=unet,
+ controlnet=controlnet,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ # pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images = []
+
+ formatted_images.append(np.asarray(validation_image))
+
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model: str = "", repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+            img_str += "\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+- controlnet
+- diffusers-training
+- webdataset
+inference: true
+---
+ """
+ model_card = f"""
+# controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with new type of conditioning.
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+        " If not specified, controlnet weights are initialized from the unet.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="controlnet-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+        help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+            " Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+            " See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step"
+            " instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=3,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=1,
+ help=("Number of subprocesses to use for data loading."),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+            "Path(s) or URL(s) of the webdataset shards to train on. Brace expansion of shard names is supported,"
+            " and a list of shard patterns may also be provided."
+ ),
+ )
+ parser.add_argument(
+ "--eval_shards_path_or_url",
+ type=str,
+ default=None,
+        help="Path(s) or URL(s) of the webdataset shards to use for evaluation.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+            "A set of paths to the controlnet conditioning image to be evaluated every `--validation_steps`"
+            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
+            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+            " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+            "Run validation every X steps. Validation consists of generating `args.num_validation_images` images"
+            " for each `args.validation_prompt` and logging them."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="sd_xl_train_controlnet",
+ help=(
+            "The `project_name` argument passed to Accelerator.init_trackers. For more information, see"
+            " https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument(
+ "--control_type",
+ type=str,
+ default="canny",
+        help="The type of controlnet conditioning image to use. One of `canny` or `depth`. Defaults to `canny`.",
+ )
+ parser.add_argument(
+ "--transformer_layers_per_block",
+ type=str,
+ default=None,
+        help="Comma-separated number of transformer layers per block for the controlnet. If None, the value from the pretrained model config is used.",
+ )
+ parser.add_argument(
+ "--old_style_controlnet",
+ action="store_true",
+ default=False,
+ help=(
+ "Use the old style controlnet, which is a single transformer layer with"
+ " a single head. Defaults to False."
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+            # We are only interested in the pooled output of the final text encoder.
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ # noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ pre_controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from unet")
+ pre_controlnet = ControlNetModel.from_unet(unet)
+
+ if args.transformer_layers_per_block is not None:
+ transformer_layers_per_block = [int(x) for x in args.transformer_layers_per_block.split(",")]
+ down_block_types = ["DownBlock2D" if l == 0 else "CrossAttnDownBlock2D" for l in transformer_layers_per_block]
+ controlnet = ControlNetModel.from_config(
+ pre_controlnet.config,
+ down_block_types=down_block_types,
+ transformer_layers_per_block=transformer_layers_per_block,
+ )
+ controlnet.load_state_dict(pre_controlnet.state_dict(), strict=False)
+ del pre_controlnet
+ else:
+ controlnet = pre_controlnet
+
+ if args.control_type == "depth":
+ feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
+ depth_model.requires_grad_(False)
+ else:
+ feature_extractor = None
+ depth_model = None
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ controlnet.train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+        " doing mixed precision training, a copy of the weights should still be in float32."
+ )
+
+ if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = controlnet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ else:
+ vae.to(accelerator.device, dtype=torch.float32)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ if args.control_type == "depth":
+ depth_model.to(accelerator.device, dtype=weight_dtype)
+
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(
+ prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True
+ ):
+ target_size = (args.resolution, args.resolution)
+ original_sizes = list(map(list, zip(*original_sizes)))
+ crops_coords_top_left = list(map(list, zip(*crop_coords)))
+
+ original_sizes = torch.tensor(original_sizes, dtype=torch.long)
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)
+
+ # crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ # add_time_ids = list(crops_coords_top_left + target_size)
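+        # SDXL micro-conditioning: the final add_time_ids concatenate the original size, the crop top-left
+        # coordinates, and the target size consumed by the SDXL UNet's additional conditioning embedding.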
+ add_time_ids = list(target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ # add_time_ids = torch.cat([torch.tensor(original_sizes, dtype=torch.long), add_time_ids], dim=-1)
+ add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ eval_shards_path_or_url=args.eval_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ center_crop=False,
+ random_flip=False,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ control_type=args.control_type,
+ feature_extractor=feature_extractor,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ controlnet, optimizer, lr_scheduler = accelerator.prepare(controlnet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(controlnet):
+ image, control_image, text, orig_size, crop_coords = batch
+
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+ image = image.to(accelerator.device, non_blocking=True)
+ control_image = control_image.to(accelerator.device, non_blocking=True)
+
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+ else:
+ pixel_values = image
+
+ # latents = vae.encode(pixel_values).latent_dist.sample()
+ # encode pixel values with batch size of at most 8
+ latents = []
+ for i in range(0, pixel_values.shape[0], 8):
+ latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ if args.control_type == "depth":
+ control_image = control_image.to(weight_dtype)
+ with torch.autocast("cuda"):
+ depth_map = depth_model(control_image).predicted_depth
+ depth_map = torch.nn.functional.interpolate(
+ depth_map.unsqueeze(1),
+ size=image.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+ control_image = (depth_map * 255.0).to(torch.uint8).float() / 255.0 # hack to match inference
+ control_image = torch.cat([control_image] * 3, dim=1)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+ sigmas = get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
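+                # EulerDiscreteScheduler expects the model input scaled by 1 / sqrt(sigma^2 + 1) (its
+                # scale_model_input), so the same scaling is applied to the noised latents here.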
+ inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)
+
+ # ControlNet conditioning.
+ controlnet_image = control_image.to(dtype=weight_dtype)
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+ down_block_res_samples, mid_block_res_sample = controlnet(
+ inp_noisy_latents,
+ timesteps,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs=encoded_text,
+ controlnet_cond=controlnet_image,
+ return_dict=False,
+ )
+
+ # Predict the noise residual
+ model_pred = unet(
+ inp_noisy_latents,
+ timesteps,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs=encoded_text,
+ down_block_additional_residuals=[
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+ ],
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+ ).sample
+
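+                # Convert the epsilon-style prediction to a denoised sample (x0 = noisy_latents - sigma * eps) and
+                # weight the x0-space error by 1 / sigma^2, which matches an unweighted epsilon-space MSE.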
+ model_pred = model_pred * (-sigmas) + noisy_latents
+ weighing = sigmas**-2.0
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = latents # compute loss against the denoised latents
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+ loss = torch.mean(
+ (weighing.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
+ )
+ loss = loss.mean()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = controlnet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ vae, unet, controlnet, args, accelerator, weight_dtype, global_step
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ controlnet = accelerator.unwrap_model(controlnet)
+ controlnet.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/diffusion_dpo/README.md b/diffusers/examples/research_projects/diffusion_dpo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..32704b6f772bc2cdf9a24e61ebe59ae78458598a
--- /dev/null
+++ b/diffusers/examples/research_projects/diffusion_dpo/README.md
@@ -0,0 +1,94 @@
+# Diffusion Model Alignment Using Direct Preference Optimization
+
+This directory provides LoRA implementations of Diffusion DPO proposed in [Diffusion Model Alignment Using Direct Preference Optimization](https://arxiv.org/abs/2311.12908) by Bram Wallace, Meihua Dang, Rafael Rafailov, Linqi Zhou, Aaron Lou, Senthil Purushwalkam, Stefano Ermon, Caiming Xiong, Shafiq Joty, and Nikhil Naik.
+
+We provide implementations for both Stable Diffusion (SD) and Stable Diffusion XL (SDXL). The original checkpoints are available at the URLs below:
+
+* [mhdang/dpo-sd1.5-text2image-v1](https://huggingface.co/mhdang/dpo-sd1.5-text2image-v1)
+* [mhdang/dpo-sdxl-text2image-v1](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1)
+
+> 💡 Note: The scripts are highly experimental and were only tested in low-data regimes. Proceed with caution. Feel free to let us know about your findings via GitHub issues.
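+
+For a quick qualitative check of the released checkpoints before training, you can load the DPO-tuned UNet into a standard pipeline. The snippet below is a minimal sketch; it assumes the checkpoint repositories expose a standard `unet` subfolder loadable with `UNet2DConditionModel`:
+
+```python
+import torch
+
+from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
+
+# Assumption: the DPO checkpoint repo ships a `unet` subfolder.
+unet = UNet2DConditionModel.from_pretrained(
+    "mhdang/dpo-sdxl-text2image-v1", subfolder="unet", torch_dtype=torch.float16
+)
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
+).to("cuda")
+
+image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", guidance_scale=5.0).images[0]
+image.save("dpo_sdxl_sample.png")
+```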
+
+## SD training command
+
+```bash
+accelerate launch train_diffusion_dpo.py \
+ --pretrained_model_name_or_path=stable-diffusion-v1-5/stable-diffusion-v1-5 \
+ --output_dir="diffusion-dpo" \
+ --mixed_precision="fp16" \
+ --dataset_name=kashif/pickascore \
+ --resolution=512 \
+ --train_batch_size=16 \
+ --gradient_accumulation_steps=2 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --rank=8 \
+ --learning_rate=1e-5 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=10000 \
+ --checkpointing_steps=2000 \
+ --run_validation --validation_steps=200 \
+ --seed="0" \
+ --report_to="wandb" \
+ --push_to_hub
+```
+
+## SDXL training command
+
+```bash
+accelerate launch train_diffusion_dpo_sdxl.py \
+ --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+ --output_dir="diffusion-sdxl-dpo" \
+ --mixed_precision="fp16" \
+ --dataset_name=kashif/pickascore \
+ --train_batch_size=8 \
+ --gradient_accumulation_steps=2 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --rank=8 \
+ --learning_rate=1e-5 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=2000 \
+ --checkpointing_steps=500 \
+ --run_validation --validation_steps=50 \
+ --seed="0" \
+ --report_to="wandb" \
+ --push_to_hub
+```
+
+## SDXL Turbo training command
+
+```bash
+accelerate launch train_diffusion_dpo_sdxl.py \
+ --pretrained_model_name_or_path=stabilityai/sdxl-turbo \
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+ --output_dir="diffusion-sdxl-turbo-dpo" \
+ --mixed_precision="fp16" \
+ --dataset_name=kashif/pickascore \
+ --train_batch_size=8 \
+ --gradient_accumulation_steps=2 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --rank=8 \
+ --learning_rate=1e-5 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=2000 \
+ --checkpointing_steps=500 \
+ --run_validation --validation_steps=50 \
+ --seed="0" \
+ --report_to="wandb" \
+ --is_turbo --resolution 512 \
+ --push_to_hub
+```
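+
+## Inference with the trained LoRA
+
+Each script saves the LoRA weights to `--output_dir` as `pytorch_lora_weights.safetensors` (the same file the in-training validation loads). Below is a minimal inference sketch for the SDXL run above (`--output_dir="diffusion-sdxl-dpo"`):
+
+```python
+import torch
+
+from diffusers import AutoencoderKL, StableDiffusionXLPipeline
+
+# Use the fp16-fix VAE, matching the training command.
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
+).to("cuda")
+
+# Load the DPO LoRA produced by train_diffusion_dpo_sdxl.py.
+pipeline.load_lora_weights("diffusion-sdxl-dpo", weight_name="pytorch_lora_weights.safetensors")
+
+image = pipeline("Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", guidance_scale=5.0).images[0]
+image.save("dpo_lora_sample.png")
+```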
+
+
+## Acknowledgements
+
+This is based on the amazing work on Diffusion DPO done by [Bram](https://github.com/bram-w): https://github.com/bram-w/trl/blob/dpo/.
diff --git a/diffusers/examples/research_projects/diffusion_dpo/requirements.txt b/diffusers/examples/research_projects/diffusion_dpo/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..11a5b9df6dad38b5807823afbdc2b0bf66fabc76
--- /dev/null
+++ b/diffusers/examples/research_projects/diffusion_dpo/requirements.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+peft
+wandb
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab88d49677664732f25ebe542da3d95f2113208c
--- /dev/null
+++ b/diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
@@ -0,0 +1,982 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 bram-w, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import io
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+import wandb
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.25.0.dev0")
+
+logger = get_logger(__name__)
+
+
+VALIDATION_PROMPTS = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+]
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):
+ logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if not is_final_validation:
+ pipeline.unet = accelerator.unwrap_model(unet)
+ else:
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = []
+ context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
+
+ for prompt in VALIDATION_PROMPTS:
+ with context:
+ image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ tracker_key: [
+ wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ # Also log images without the LoRA params for comparison.
+ if is_final_validation:
+ pipeline.disable_lora()
+ no_lora_images = [
+ pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in no_lora_images])
+ tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test_without_lora": [
+ wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
+ for i, image in enumerate(no_lora_images)
+ ]
+ }
+ )
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_split_name",
+ type=str,
+ default="validation",
+ help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--run_validation",
+ default=False,
+ action="store_true",
+ help="Whether to run validation inference in between training and also after training. Helps to track progress.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="diffusion-dpo-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=8,
+ help="Batch size to use for VAE encoding of the images for efficient processing.",
+ )
+ parser.add_argument(
+ "--no_hflip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--random_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--beta_dpo",
+ type=int,
+ default=2500,
+ help="DPO KL Divergence penalty.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="sigmoid",
+ help="DPO loss type. Can be one of 'sigmoid' (default), 'ipo', or 'cpo'",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--tracker_name",
+ type=str,
+ default="diffusion-dpo-lora",
+ help=("The name of the tracker to report results to."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None:
+ raise ValueError("Must provide a `dataset_name`.")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+def tokenize_captions(tokenizer, examples):
+ max_length = tokenizer.model_max_length
+ captions = []
+ for caption in examples["caption"]:
+ captions.append(caption)
+
+ text_inputs = tokenizer(
+ captions, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+
+ return text_inputs.input_ids
+
+
+@torch.no_grad()
+def encode_prompt(text_encoder, input_ids):
+ text_input_ids = input_ids.to(text_encoder.device)
+ attention_mask = None
+
+ prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Set up LoRA.
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ # Add adapter and make sure the trainable params are in float32.
+ unet.add_adapter(unet_lora_config)
+ if args.mixed_precision == "fp16":
+ for param in unet.parameters():
+ # only upcast trainable parameters (LoRA) into fp32
+ if param.requires_grad:
+ param.data = param.to(torch.float32)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # There are only two options here: either just the unet attention processor layers,
+ # or both the unet and text encoder attention layers.
+ unet_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=None,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+ StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = load_dataset(
+ args.dataset_name,
+ cache_dir=args.cache_dir,
+ split=args.dataset_split_name,
+ )
+
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(int(args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution),
+ transforms.Lambda(lambda x: x) if args.no_hflip else transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ all_pixel_values = []
+ for col_name in ["jpg_0", "jpg_1"]:
+ images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
+ pixel_values = [train_transforms(image) for image in images]
+ all_pixel_values.append(pixel_values)
+
+ # Concatenate along the channel dim, with the preferred (winner) image first, then the rejected one
+ im_tup_iterator = zip(*all_pixel_values)
+ combined_pixel_values = []
+ for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
+ if label_0 == 0:
+ im_tup = im_tup[::-1]
+ combined_im = torch.cat(im_tup, dim=0) # no batch dim
+ combined_pixel_values.append(combined_im)
+ examples["pixel_values"] = combined_pixel_values
+
+ examples["input_ids"] = tokenize_captions(tokenizer, examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = train_dataset.with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ final_dict = {"pixel_values": pixel_values}
+ final_dict["input_ids"] = torch.stack([example["input_ids"] for example in examples])
+ return final_dict
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers(args.tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ unet.train()
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+ feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
+
+ latents = []
+ for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
+ latents.append(
+ vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
+ )
+ latents = torch.cat(latents, dim=0)
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
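+ # The first half of the noise is repeated so the preferred and rejected image of each pair are corrupted with the same noise (and, below, the same timestep).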
+ noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
+
+ # Sample a random timestep for each image
+ bsz = latents.shape[0] // 2
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
+ ).repeat(2)
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"]).repeat(2, 1, 1)
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ encoder_hidden_states,
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Compute losses.
+ model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
+ model_losses_w, model_losses_l = model_losses.chunk(2)
+
+ # For logging
+ raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
+ model_diff = model_losses_w - model_losses_l # both have shape (batch_size,)
+
+ # Reference model predictions.
+ accelerator.unwrap_model(unet).disable_adapters()
+ with torch.no_grad():
+ ref_preds = unet(
+ noisy_model_input,
+ timesteps,
+ encoder_hidden_states,
+ ).sample.detach()
+ ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
+ ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
+
+ ref_losses_w, ref_losses_l = ref_loss.chunk(2)
+ ref_diff = ref_losses_w - ref_losses_l
+ raw_ref_loss = ref_loss.mean()
+
+ # Re-enable adapters.
+ accelerator.unwrap_model(unet).enable_adapters()
+
+ # Final loss.
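+ # Diffusion-DPO: logits > 0 means the LoRA model lowers the denoising error on the
+ # preferred sample (relative to the frozen reference) more than on the rejected one.
+ # Each loss below encourages positive logits, scaled by beta_dpo.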
+ logits = ref_diff - model_diff
+ if args.loss_type == "sigmoid":
+ loss = -1 * F.logsigmoid(args.beta_dpo * logits).mean()
+ elif args.loss_type == "hinge":
+ loss = torch.relu(1 - args.beta_dpo * logits).mean()
+ elif args.loss_type == "ipo":
+ losses = (logits - 1 / (2 * args.beta_dpo)) ** 2
+ loss = losses.mean()
+ else:
+ raise ValueError(f"Unknown loss type {args.loss_type}")
+
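+ # Implicit accuracy: fraction of pairs whose implicit reward favors the preferred image (ties count as 0.5).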
+ implicit_acc = (logits > 0).sum().float() / logits.size(0)
+ implicit_acc += 0.5 * (logits == 0).sum().float() / logits.size(0)
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.run_validation and global_step % args.validation_steps == 0:
+ log_validation(
+ args, unet=unet, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
+ )
+
+ logs = {
+ "loss": loss.detach().item(),
+ "raw_model_loss": raw_model_loss.detach().item(),
+ "ref_loss": raw_ref_loss.detach().item(),
+ "implicit_acc": implicit_acc.detach().item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ }
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
+ save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=None
+ )
+
+ # Final validation?
+ if args.run_validation:
+ log_validation(
+ args,
+ unet=None,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..0297a06f5b2ce18dcee2c438a06dc0cdc33ac12b
--- /dev/null
+++ b/diffusers/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
@@ -0,0 +1,1140 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 bram-w, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import io
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+import wandb
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionXLLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.25.0.dev0")
+
+logger = get_logger(__name__)
+
+
+VALIDATION_PROMPTS = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+]
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
+ logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+
+ if is_final_validation:
+ if args.mixed_precision == "fp16":
+ vae.to(weight_dtype)
+
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if not is_final_validation:
+ pipeline.unet = accelerator.unwrap_model(unet)
+ else:
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = []
+ context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
+
+ guidance_scale = 5.0
+ num_inference_steps = 25
+ if args.is_turbo:
+ guidance_scale = 0.0
+ num_inference_steps = 4
+ for prompt in VALIDATION_PROMPTS:
+ with context:
+ image = pipeline(
+ prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+ ).images[0]
+ images.append(image)
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ tracker_key: [
+ wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ # Also log images without the LoRA params for comparison.
+ if is_final_validation:
+ pipeline.disable_lora()
+ no_lora_images = [
+ pipeline(
+ prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+ ).images[0]
+ for prompt in VALIDATION_PROMPTS
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in no_lora_images])
+ tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test_without_lora": [
+ wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
+ for i, image in enumerate(no_lora_images)
+ ]
+ }
+ )
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_split_name",
+ type=str,
+ default="validation",
+ help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--run_validation",
+ default=False,
+ action="store_true",
+ help="Whether to run validation inference in between training and also after training. Helps to track progress.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="diffusion-dpo-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=8,
+ help="Batch size to use for VAE encoding of the images for efficient processing.",
+ )
+ parser.add_argument(
+ "--no_hflip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--random_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--beta_dpo",
+ type=int,
+ default=5000,
+ help="DPO KL Divergence penalty.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--is_turbo",
+ action="store_true",
+ help=("Use if tuning SDXL Turbo instead of SDXL"),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--tracker_name",
+ type=str,
+ default="diffusion-dpo-lora-sdxl",
+ help=("The name of the tracker to report results to."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None:
+ raise ValueError("Must provide a `dataset_name`.")
+
+ if args.is_turbo:
+ assert "turbo" in args.pretrained_model_name_or_path
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+def tokenize_captions(tokenizers, examples):
+ captions = []
+ for caption in examples["caption"]:
+ captions.append(caption)
+
+ tokens_one = tokenizers[0](
+ captions, truncation=True, padding="max_length", max_length=tokenizers[0].model_max_length, return_tensors="pt"
+ ).input_ids
+ tokens_two = tokenizers[1](
+ captions, truncation=True, padding="max_length", max_length=tokenizers[1].model_max_length, return_tensors="pt"
+ ).input_ids
+
+ return tokens_one, tokens_two
+
+
+@torch.no_grad()
+def encode_prompt(text_encoders, text_input_ids_list):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ def enforce_zero_terminal_snr(scheduler):
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
+ # Original implementation https://arxiv.org/pdf/2305.08891.pdf
+ # Turbo needs zero terminal SNR
+ # Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
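+ # The rescaling below shifts and scales sqrt(alpha_bar) linearly so that the terminal value
+ # sqrt(alpha_bar_T) becomes exactly 0 while sqrt(alpha_bar_0) is preserved, then recovers the
+ # per-step alphas and the cumulative product used by the scheduler.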
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1 - scheduler.betas
+ alphas_bar = alphas.cumprod(0)
+ alphas_bar_sqrt = alphas_bar.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+ # Shift so last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+ # Scale so first timestep is back to old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ alphas_bar = alphas_bar_sqrt**2
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ scheduler.alphas_cumprod = alphas_cumprod
+ return
+
+ if args.is_turbo:
+ enforce_zero_terminal_snr(noise_scheduler)
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
+ # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet and text_encoders to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ # Set up LoRA.
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ # Add adapter and make sure the trainable params are in float32.
+ unet.add_adapter(unet_lora_config)
+ if args.mixed_precision == "fp16":
+ for param in unet.parameters():
+ # only upcast trainable parameters (LoRA) into fp32
+ if param.requires_grad:
+ param.data = param.to(torch.float32)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # There are only two options here: either just the unet attention processor (LoRA) layers,
+ # or the unet and text encoder attention layers.
+ unet_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop the weight so that the corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLLoraLoaderMixin.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save)
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
+ StableDiffusionXLLoraLoaderMixin.load_lora_into_unet(
+ lora_state_dict, network_alphas=network_alphas, unet=unet_
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = load_dataset(
+ args.dataset_name,
+ cache_dir=args.cache_dir,
+ split=args.dataset_split_name,
+ )
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+
+ def preprocess_train(examples):
+ all_pixel_values = []
+ images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples["jpg_0"]]
+ original_sizes = [(image.height, image.width) for image in images]
+ crop_top_lefts = []
+
+ for col_name in ["jpg_0", "jpg_1"]:
+ images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
+ if col_name == "jpg_1":
+ # Need to bring down the image to the same resolution.
+ # This seems like the simplest reasonable approach.
+ # "::-1" because PIL resize takes (width, height).
+ images = [image.resize(original_sizes[i][::-1]) for i, image in enumerate(images)]
+ pixel_values = [to_tensor(image) for image in images]
+ all_pixel_values.append(pixel_values)
+
+ # Stack each pair along the channel dim: the preferred image first, then the rejected one.
+ im_tup_iterator = zip(*all_pixel_values)
+ combined_pixel_values = []
+ for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
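+ # label_0 == 0 means jpg_0 lost the comparison, so reverse the tuple to keep the preferred image first.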
+ if label_0 == 0:
+ im_tup = im_tup[::-1]
+
+ combined_im = torch.cat(im_tup, dim=0) # no batch dim
+
+ # Resize.
+ combined_im = train_resize(combined_im)
+
+ # Flipping.
+ if not args.no_hflip and random.random() < 0.5:
+ combined_im = train_flip(combined_im)
+
+ # Cropping.
+ if not args.random_crop:
+ y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
+ x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))
+ combined_im = train_crop(combined_im)
+ else:
+ y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
+ combined_im = crop(combined_im, y1, x1, h, w)
+
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ combined_im = normalize(combined_im)
+ combined_pixel_values.append(combined_im)
+
+ examples["pixel_values"] = combined_pixel_values
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], examples)
+ examples["input_ids_one"] = tokens_one
+ examples["input_ids_two"] = tokens_two
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = train_dataset.with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
+ input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "input_ids_one": input_ids_one,
+ "input_ids_two": input_ids_two,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers(args.tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ unet.train()
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
+
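+ # Encode through the VAE in sub-batches of vae_encode_batch_size to bound peak memory.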
+ latents = []
+ for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
+ latents.append(
+ vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
+ )
+ latents = torch.cat(latents, dim=0)
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents; the same noise is shared by the preferred and rejected halves
+ noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
+
+ # Sample a random timestep for each image pair (shared between the preferred and rejected halves)
+ bsz = latents.shape[0] // 2
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
+ ).repeat(2)
+ if args.is_turbo:
+ # Learn a 4-timestep schedule: timesteps are mapped to {249, 499, 749, 999}
+ timesteps_0_to_3 = timesteps % 4
+ timesteps = 250 * timesteps_0_to_3 + 249
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # time ids
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
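+ # SDXL micro-conditioning: the UNet receives (original_size, crop_top_left, target_size) as additional time ids.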
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ ).repeat(2, 1)
+
+ # Get the text embedding for conditioning
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ [text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
+ )
+ prompt_embeds = prompt_embeds.repeat(2, 1, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Compute losses.
+ model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
+ model_losses_w, model_losses_l = model_losses.chunk(2)
+
+ # For logging
+ raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
+ model_diff = model_losses_w - model_losses_l # per-sample loss gap, shape (batch_size,)
+
+ # Reference model predictions.
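+ # Disabling the LoRA adapters temporarily recovers the frozen base UNet, which serves as the
+ # DPO reference model; the adapters are re-enabled after this forward pass.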
+ accelerator.unwrap_model(unet).disable_adapters()
+ with torch.no_grad():
+ ref_preds = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
+ ).sample
+ ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
+ ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
+
+ ref_losses_w, ref_losses_l = ref_loss.chunk(2)
+ ref_diff = ref_losses_w - ref_losses_l
+ raw_ref_loss = ref_loss.mean()
+
+ # Re-enable adapters.
+ accelerator.unwrap_model(unet).enable_adapters()
+
+ # Final loss.
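+ # Diffusion-DPO: loss = -log sigmoid( -(beta_dpo / 2) * [(model_losses_w - model_losses_l) - (ref_losses_w - ref_losses_l)] ),
+ # i.e. the adapter is rewarded for shrinking its winner-vs-loser loss gap below the reference model's gap.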
+ scale_term = -0.5 * args.beta_dpo
+ inside_term = scale_term * (model_diff - ref_diff)
+ loss = -1 * F.logsigmoid(inside_term).mean()
+
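+ # Implicit preference accuracy: fraction of pairs with a positive inside term (ties count as 0.5).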
+ implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
+ implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.run_validation and global_step % args.validation_steps == 0:
+ log_validation(
+ args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
+ )
+
+ logs = {
+ "loss": loss.detach().item(),
+ "raw_model_loss": raw_model_loss.detach().item(),
+ "ref_loss": raw_ref_loss.detach().item(),
+ "implicit_acc": implicit_acc.detach().item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ }
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ StableDiffusionXLLoraLoaderMixin.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=None,
+ text_encoder_2_lora_layers=None,
+ )
+
+ # Run a final validation pass.
+ if args.run_validation:
+ log_validation(
+ args,
+ unet=None,
+ vae=vae,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/diffusion_orpo/README.md b/diffusers/examples/research_projects/diffusion_orpo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3f1ee0413e33e451b5778e3aff72c9cdd0650772
--- /dev/null
+++ b/diffusers/examples/research_projects/diffusion_orpo/README.md
@@ -0,0 +1 @@
+This project has a new home now: [https://mapo-t2i.github.io/](https://mapo-t2i.github.io/). We formally studied the use of ORPO in the context of diffusion models and open-sourced our codebase, models, and datasets. We released our paper too!
diff --git a/diffusers/examples/research_projects/diffusion_orpo/requirements.txt b/diffusers/examples/research_projects/diffusion_orpo/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9c8c6600b824c9e1f811577e9fcc6bd26a53f0e2
--- /dev/null
+++ b/diffusers/examples/research_projects/diffusion_orpo/requirements.txt
@@ -0,0 +1,7 @@
+datasets
+accelerate
+transformers
+torchvision
+wandb
+peft
+webdataset
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdc096190f0873aa85d9d839c0d252282e26ab37
--- /dev/null
+++ b/diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
@@ -0,0 +1,1092 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import io
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+import wandb
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionXLLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.25.0.dev0")
+
+logger = get_logger(__name__)
+
+
+VALIDATION_PROMPTS = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+]
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
+ logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+
+ if is_final_validation:
+ if args.mixed_precision == "fp16":
+ vae.to(weight_dtype)
+
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if not is_final_validation:
+ pipeline.unet = accelerator.unwrap_model(unet)
+ else:
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = []
+ context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
+
+ guidance_scale = 5.0
+ num_inference_steps = 25
+ for prompt in VALIDATION_PROMPTS:
+ with context:
+ image = pipeline(
+ prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+ ).images[0]
+ images.append(image)
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ tracker_key: [
+ wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ # Also log images without the LoRA params for comparison.
+ if is_final_validation:
+ pipeline.disable_lora()
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ no_lora_images = [
+ pipeline(
+ prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+ ).images[0]
+ for prompt in VALIDATION_PROMPTS
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in no_lora_images])
+ tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test_without_lora": [
+ wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
+ for i, image in enumerate(no_lora_images)
+ ]
+ }
+ )
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_split_name",
+ type=str,
+ default="validation",
+ help="Dataset split to be used during training. Helpful to specify for conducting experimental runs.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--run_validation",
+ default=False,
+ action="store_true",
+ help="Whether to run validation inference in between training and also after training. Helps to track progress.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="diffusion-orpo-lora-sdxl",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=8,
+ help="Batch size to use for VAE encoding of the images for efficient processing.",
+ )
+ parser.add_argument(
+ "--no_hflip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--random_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--beta_orpo",
+ type=float,
+ default=0.1,
+ help="ORPO contribution factor.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--tracker_name",
+ type=str,
+ default="diffusion-orpo-lora-sdxl",
+ help=("The name of the tracker to report results to."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None:
+ raise ValueError("Must provide a `dataset_name`.")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+def tokenize_captions(tokenizers, examples):
+ captions = []
+ for caption in examples["caption"]:
+ captions.append(caption)
+
+ tokens_one = tokenizers[0](
+ captions, truncation=True, padding="max_length", max_length=tokenizers[0].model_max_length, return_tensors="pt"
+ ).input_ids
+ tokens_two = tokenizers[1](
+ captions, truncation=True, padding="max_length", max_length=tokenizers[1].model_max_length, return_tensors="pt"
+ ).input_ids
+
+ return tokens_one, tokens_two
+
+
+@torch.no_grad()
+def encode_prompt(text_encoders, text_input_ids_list):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # Only the pooled output of the final text encoder is used (the variable is overwritten on each iteration)
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
+ # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet and text_encoders to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ # Set up LoRA.
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ # Add adapter and make sure the trainable params are in float32.
+ unet.add_adapter(unet_lora_config)
+ if args.mixed_precision == "fp16":
+ for param in unet.parameters():
+ # only upcast trainable parameters (LoRA) into fp32
+ if param.requires_grad:
+ param.data = param.to(torch.float32)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # There are only two options here: either just the unet attention processor (LoRA) layers,
+ # or the unet and text encoder attention layers.
+ unet_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop the weight so that the corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLLoraLoaderMixin.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=None,
+ text_encoder_2_lora_layers=None,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
+
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = load_dataset(
+ args.dataset_name,
+ cache_dir=args.cache_dir,
+ split=args.dataset_split_name,
+ )
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+
+ def preprocess_train(examples):
+ all_pixel_values = []
+ images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples["jpg_0"]]
+ original_sizes = [(image.height, image.width) for image in images]
+ crop_top_lefts = []
+
+ for col_name in ["jpg_0", "jpg_1"]:
+ images = [Image.open(io.BytesIO(im_bytes)).convert("RGB") for im_bytes in examples[col_name]]
+ if col_name == "jpg_1":
+ # Need to bring down the image to the same resolution.
+ # This seems like the simplest reasonable approach.
+ # "::-1" because PIL resize takes (width, height).
+ images = [image.resize(original_sizes[i][::-1]) for i, image in enumerate(images)]
+ pixel_values = [to_tensor(image) for image in images]
+ all_pixel_values.append(pixel_values)
+
+ # Stack each pair along the channel dim: the preferred image first, then the rejected one.
+ im_tup_iterator = zip(*all_pixel_values)
+ combined_pixel_values = []
+ for im_tup, label_0 in zip(im_tup_iterator, examples["label_0"]):
+ # For ties (label_0 == 0.5), randomly decide which image is treated as preferred.
+ if label_0 == 0.5:
+ if random.random() < 0.5:
+ label_0 = 0
+ else:
+ label_0 = 1
+
+ if label_0 == 0:
+ im_tup = im_tup[::-1]
+
+ combined_im = torch.cat(im_tup, dim=0) # no batch dim
+
+ # Resize.
+ combined_im = train_resize(combined_im)
+
+ # Flipping.
+ if not args.no_hflip and random.random() < 0.5:
+ combined_im = train_flip(combined_im)
+
+ # Cropping.
+ if not args.random_crop:
+ y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
+ x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))
+ combined_im = train_crop(combined_im)
+ else:
+ y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
+ combined_im = crop(combined_im, y1, x1, h, w)
+
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ combined_im = normalize(combined_im)
+ combined_pixel_values.append(combined_im)
+
+ examples["pixel_values"] = combined_pixel_values
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], examples)
+ examples["input_ids_one"] = tokens_one
+ examples["input_ids_two"] = tokens_two
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = train_dataset.with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
+ input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "input_ids_one": input_ids_one,
+ "input_ids_two": input_ids_two,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers(args.tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ unet.train()
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
+
+ latents = []
+ for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
+ latents.append(
+ vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
+ )
+ latents = torch.cat(latents, dim=0)
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
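+ # The same noise (and the same timesteps below) is shared between the preferred and
+ # rejected halves of the batch, so the two reconstruction losses are directly comparable.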
+ noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
+
+ # Sample a random timestep for each image
+ bsz = latents.shape[0] // 2
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
+ ).repeat(2)
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # time ids
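+ # SDXL is additionally conditioned on the original image size, the crop top-left
+ # coordinates, and the target size; these micro-conditioning values are packed into add_time_ids.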
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ ).repeat(2, 1)
+
+ # Get the text embedding for conditioning
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ [text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
+ )
+ prompt_embeds = prompt_embeds.repeat(2, 1, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # ODDS ratio loss.
+ # In the diffusion formulation, we're assuming that the MSE loss
+ # approximates the logp.
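+ # With per-sample losses L_w (preferred half) and L_l (rejected half), the lines below
+ # compute loss = L_w.mean() - beta_orpo * logsigmoid(L_w - L_l).mean(): the denoising loss
+ # on the preferred images minus a beta_orpo-scaled log-sigmoid term on the loss gap.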
+ model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
+ model_losses_w, model_losses_l = model_losses.chunk(2)
+ log_odds = model_losses_w - model_losses_l
+
+ # Ratio loss.
+ ratio = F.logsigmoid(log_odds)
+ ratio_losses = args.beta_orpo * ratio
+
+ # Full ORPO loss
+ loss = model_losses_w.mean() - ratio_losses.mean()
+
+ # Backprop.
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.run_validation and global_step % args.validation_steps == 0:
+ log_validation(
+ args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ StableDiffusionXLLoraLoaderMixin.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=None,
+ text_encoder_2_lora_layers=None,
+ )
+
+ # Final validation?
+ if args.run_validation:
+ log_validation(
+ args,
+ unet=None,
+ vae=vae,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd1ef265d23eb3518847c00db1ea111afc6f1ba5
--- /dev/null
+++ b/diffusers/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
@@ -0,0 +1,1095 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+import wandb
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionXLLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.25.0.dev0")
+
+logger = get_logger(__name__)
+
+
+VALIDATION_PROMPTS = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+]
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):
+ logger.info(f"Running validation... \n Generating images with prompts:\n" f" {VALIDATION_PROMPTS}.")
+
+ if is_final_validation:
+ if args.mixed_precision == "fp16":
+ vae.to(weight_dtype)
+
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if not is_final_validation:
+ pipeline.unet = accelerator.unwrap_model(unet)
+ else:
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = []
+ context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
+
+ guidance_scale = 5.0
+ num_inference_steps = 25
+ for prompt in VALIDATION_PROMPTS:
+ with context:
+ image = pipeline(
+ prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+ ).images[0]
+ images.append(image)
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ tracker_key: [
+ wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ # Also log images without the LoRA params for comparison.
+ if is_final_validation:
+ pipeline.disable_lora()
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ no_lora_images = [
+ pipeline(
+ prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+ ).images[0]
+ for prompt in VALIDATION_PROMPTS
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in no_lora_images])
+ tracker.writer.add_images("test_without_lora", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test_without_lora": [
+ wandb.Image(image, caption=f"{i}: {VALIDATION_PROMPTS[i]}")
+ for i, image in enumerate(no_lora_images)
+ ]
+ }
+ )
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_path",
+ type=str,
+ default="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -",
+ )
+ parser.add_argument(
+ "--num_train_examples",
+ type=int,
+ default=1001352,
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--run_validation",
+ default=False,
+ action="store_true",
+ help="Whether to run validation inference in between training and also after training. Helps to track progress.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="diffusion-orpo-lora-sdxl",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=8,
+ help="Batch size to use for VAE encoding of the images for efficient processing.",
+ )
+ parser.add_argument(
+ "--no_hflip",
+ action="store_true",
+ help="Whether to suppress random horizontal flipping of the images.",
+ )
+ parser.add_argument(
+ "--random_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to random crop the input images to the resolution. If not set, the images will be center-cropped."
+ ),
+ )
+ parser.add_argument("--global_batch_size", type=int, default=64, help="Total batch size.")
+ parser.add_argument(
+ "--per_gpu_batch_size", type=int, default=8, help="Number of samples in a batch for a single GPU."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--beta_orpo",
+ type=float,
+ default=0.1,
+ help="ORPO contribution factor.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--tracker_name",
+ type=str,
+ default="diffusion-orpo-lora-sdxl",
+ help=("The name of the tracker to report results to."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+def tokenize_captions(tokenizers, sample):
+ tokens_one = tokenizers[0](
+ sample["original_prompt"],
+ truncation=True,
+ padding="max_length",
+ max_length=tokenizers[0].model_max_length,
+ return_tensors="pt",
+ ).input_ids
+ tokens_two = tokenizers[1](
+ sample["original_prompt"],
+ truncation=True,
+ padding="max_length",
+ max_length=tokenizers[1].model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ return tokens_one, tokens_two
+
+
+@torch.no_grad()
+def encode_prompt(text_encoders, text_input_ids_list):
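+ # SDXL uses two text encoders: the penultimate hidden states of both are concatenated along
+ # the feature dimension, while the pooled output is taken from the final text encoder only.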
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final text encoder.
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def get_dataset(args):
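+ # `resampled=True` turns the shard list into an infinite resampled stream, so the epoch
+ # length is defined explicitly later via `with_epoch(...)` in `get_loader`.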
+ dataset = (
+ wds.WebDataset(args.dataset_path, resampled=True, handler=wds.warn_and_continue)
+ .shuffle(690, handler=wds.warn_and_continue)
+ .decode("pil", handler=wds.warn_and_continue)
+ .rename(
+ original_prompt="original_prompt.txt",
+ jpg_0="jpg_0.jpg",
+ jpg_1="jpg_1.jpg",
+ label_0="label_0.txt",
+ label_1="label_1.txt",
+ handler=wds.warn_and_continue,
+ )
+ )
+ return dataset
+
+
+def get_loader(args, tokenizer_one, tokenizer_two):
+ # The default --num_train_examples is 1,001,352.
+ num_batches = math.ceil(args.num_train_examples / args.global_batch_size)
+ # Round so that every dataloader worker sees a whole number of batches, and guard against a
+ # division by zero when --dataloader_num_workers is left at its default of 0.
+ num_workers = max(args.dataloader_num_workers, 1)
+ num_worker_batches = math.ceil(args.num_train_examples / (args.global_batch_size * num_workers))  # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * args.global_batch_size
+
+ dataset = get_dataset(args)
+
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
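+ # p=1.0 because the flip decision is made manually below with `random.random()`.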
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+
+ def preprocess_images(sample):
+ jpg_0_image = sample["jpg_0"]
+ original_size = (jpg_0_image.height, jpg_0_image.width)
+ crop_top_left = []
+
+ jpg_1_image = sample["jpg_1"]
+ # Bring jpg_1 down to jpg_0's resolution so the two images can be stacked later.
+ # This seems like the simplest reasonable approach.
+ # "::-1" because PIL resize takes (width, height).
+ jpg_1_image = jpg_1_image.resize(original_size[::-1])
+
+ # We randomize selection and rejection: ties (label 0.5) are broken uniformly at random.
+ label_0 = sample["label_0"]
+ if sample["label_0"] == 0.5:
+ if random.random() < 0.5:
+ label_0 = 0
+ else:
+ label_0 = 1
+
+ # Double on the channel dim: the preferred (winning) image first, then the rejected one.
+ if label_0 == 0:
+ pixel_values = torch.cat([to_tensor(image) for image in [jpg_1_image, jpg_0_image]])
+ else:
+ pixel_values = torch.cat([to_tensor(image) for image in [jpg_0_image, jpg_1_image]])
+
+ # Resize.
+ combined_im = train_resize(pixel_values)
+
+ # Flipping.
+ if not args.no_hflip and random.random() < 0.5:
+ combined_im = train_flip(combined_im)
+
+ # Cropping.
+ if not args.random_crop:
+ y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
+ x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))
+ combined_im = train_crop(combined_im)
+ else:
+ y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
+ combined_im = crop(combined_im, y1, x1, h, w)
+
+ crop_top_left = (y1, x1)
+ combined_im = normalize(combined_im)
+ tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], sample)
+
+ return {
+ "pixel_values": combined_im,
+ "original_size": original_size,
+ "crop_top_left": crop_top_left,
+ "tokens_one": tokens_one,
+ "tokens_two": tokens_two,
+ }
+
+ def collate_fn(samples):
+ pixel_values = torch.stack([sample["pixel_values"] for sample in samples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ original_sizes = [example["original_size"] for example in samples]
+ crop_top_lefts = [example["crop_top_left"] for example in samples]
+ input_ids_one = torch.stack([example["tokens_one"] for example in samples])
+ input_ids_two = torch.stack([example["tokens_two"] for example in samples])
+
+ return {
+ "pixel_values": pixel_values,
+ "input_ids_one": input_ids_one,
+ "input_ids_two": input_ids_two,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ dataset = dataset.map(preprocess_images, handler=wds.warn_and_continue)
+ dataset = dataset.batched(args.per_gpu_batch_size, partial=False, collation_fn=collate_fn)
+ dataset = dataset.with_epoch(num_worker_batches)
+
+ dataloader = wds.WebLoader(
+ dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=args.dataloader_num_workers,
+ pin_memory=True,
+ persistent_workers=args.dataloader_num_workers > 0,  # persistent workers require num_workers > 0
+ )
+ # add meta-data to dataloader instance for convenience
+ dataloader.num_batches = num_batches
+ dataloader.num_samples = num_samples
+ return dataloader
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast the non-trainable weights (text encoders and the non-LoRA UNet)
+ # to half precision, as these weights are only used for inference and keeping them in full precision
+ # is not required. The VAE is handled separately below and kept in float32.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet and text_encoders to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ # Set up LoRA.
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ # Add adapter and make sure the trainable params are in float32.
+ unet.add_adapter(unet_lora_config)
+ if args.mixed_precision == "fp16":
+ for param in unet.parameters():
+ # only upcast trainable parameters (LoRA) into fp32
+ if param.requires_grad:
+ param.data = param.to(torch.float32)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # There are only two options here: either just the UNet attention processor layers,
+ # or the UNet plus the text encoder attention layers.
+ unet_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLLoraLoaderMixin.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=None,
+ text_encoder_2_lora_layers=None,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ else:
+ raise ValueError(f"unexpected model when loading: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)
+
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
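+ # Only the LoRA adapter parameters require grad at this point, so they are the only ones optimized.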
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes
+ train_dataloader = get_loader(args, tokenizer_one=tokenizer_one, tokenizer_two=tokenizer_two)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers(args.tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.per_gpu_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {train_dataloader.num_samples}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.per_gpu_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ unet.train()
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype, device=accelerator.device, non_blocking=True)
+ feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
+
+ latents = []
+ for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
+ latents.append(
+ vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
+ )
+ latents = torch.cat(latents, dim=0)
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
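+ # The same noise (and the same timesteps below) is shared between the preferred and
+ # rejected halves of the batch, so the two reconstruction losses are directly comparable.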
+ noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
+
+ # Sample a random timestep for each image
+ bsz = latents.shape[0] // 2
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
+ ).repeat(2)
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # time ids
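+ # SDXL is additionally conditioned on the original image size, the crop top-left
+ # coordinates, and the target size; these micro-conditioning values are packed into add_time_ids.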
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(tuple(original_size) + tuple(crops_coords_top_left) + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ ).repeat(2, 1)
+
+ # Get the text embedding for conditioning
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ [text_encoder_one, text_encoder_two], [batch["input_ids_one"], batch["input_ids_two"]]
+ )
+ prompt_embeds = prompt_embeds.repeat(2, 1, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs={"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds},
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # ODDS ratio loss.
+ # In the diffusion formulation, we're assuming that the MSE loss
+ # approximates the logp.
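+ # With per-sample losses L_w (preferred half) and L_l (rejected half), the lines below
+ # compute loss = L_w.mean() - beta_orpo * logsigmoid(L_w - L_l).mean(): the denoising loss
+ # on the preferred images minus a beta_orpo-scaled log-sigmoid term on the loss gap.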
+ model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
+ model_losses_w, model_losses_l = model_losses.chunk(2)
+ log_odds = model_losses_w - model_losses_l
+
+ # Ratio loss.
+ ratio = F.logsigmoid(log_odds)
+ ratio_losses = args.beta_orpo * ratio
+
+ # Full ORPO loss
+ loss = model_losses_w.mean() - ratio_losses.mean()
+
+ # Backprop.
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.run_validation and global_step % args.validation_steps == 0:
+ log_validation(
+ args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ StableDiffusionXLLoraLoaderMixin.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=None,
+ text_encoder_2_lora_layers=None,
+ )
+
+ # Final validation?
+ if args.run_validation:
+ log_validation(
+ args,
+ unet=None,
+ vae=vae,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/dreambooth_inpaint/README.md b/diffusers/examples/research_projects/dreambooth_inpaint/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..46703fa982e50ecaea2c71e07c564e5509215dcb
--- /dev/null
+++ b/diffusers/examples/research_projects/dreambooth_inpaint/README.md
@@ -0,0 +1,118 @@
+# Dreambooth for the inpainting model
+
+This script was added by @thedarkzeno.
+
+Please note that this script is not actively maintained; however, you can open an issue and tag @thedarkzeno or @patil-suraj.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400
+```
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language drift. Refer to the paper to learn more about it. For prior-preservation, we first generate images using the model with a class prompt and then use those images during training along with our own data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation; 200-300 images work well for most cases.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+
+### Training with gradient checkpointing and 8-bit optimizer:
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to run DreamBooth training on a 16GB GPU.
+
+To install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Fine-tune text encoder with the UNet.
+
+The script also allows you to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
+Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.
+
+___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU and needs at least 24GB of VRAM.___
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
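+
+### Inference
+
+Assuming the run has saved the fine-tuned pipeline to `$OUTPUT_DIR`, it can be loaded back for inpainting. The snippet below is a minimal sketch; the image and mask filenames are placeholders, and the mask should be white wherever content is to be regenerated.
+
+```python
+import torch
+from PIL import Image
+
+from diffusers import StableDiffusionInpaintPipeline
+
+# Load the fine-tuned inpainting pipeline from the training output directory.
+pipe = StableDiffusionInpaintPipeline.from_pretrained("path-to-save-model", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+# The inpainting pipeline expects an RGB image and a single-channel mask of the same size.
+image = Image.open("dog.png").convert("RGB").resize((512, 512))
+mask = Image.open("dog_mask.png").convert("L").resize((512, 512))
+
+result = pipe(prompt="a photo of sks dog", image=image, mask_image=mask).images[0]
+result.save("inpainted_dog.png")
+```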
diff --git a/diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt b/diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aad6387026f181053d1872fd1961c7b56e86f1df
--- /dev/null
+++ b/diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt
@@ -0,0 +1,7 @@
+diffusers==0.9.0
+accelerate>=0.16.0
+torchvision
+transformers>=4.21.0
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py b/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b484c89fdbd87720084675ea106fdf1223ce231
--- /dev/null
+++ b/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
@@ -0,0 +1,812 @@
+import argparse
+import itertools
+import math
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image, ImageDraw
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.13.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def prepare_mask_and_masked_image(image, mask):
+ # Convert the RGB image to a (1, 3, H, W) float tensor scaled to [-1, 1].
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # Binarize the grayscale mask into a (1, 1, H, W) tensor of 0s and 1s.
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ # Zero out the masked region to get the masked image the inpainting UNet is conditioned on.
+ masked_image = image * (mask < 0.5)
+
+ return mask, masked_image
+
+
+# generate random masks
+def random_mask(im_shape, ratio=1, mask_full_image=False):
+ mask = Image.new("L", im_shape, 0)
+ draw = ImageDraw.Draw(mask)
+ size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
+ # use this to always mask the whole image
+ if mask_full_image:
+ size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
+ limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
+ center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
+ draw_type = random.randint(0, 1)
+ if draw_type == 0 or mask_full_image:
+ draw.rectangle(
+ (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+ fill=255,
+ )
+ else:
+ draw.ellipse(
+ (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+ fill=255,
+ )
+
+ return mask
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal number of class images for the prior preservation loss. If there are not enough images"
+ " already present in class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
+ " using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.instance_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms_resize_and_crop = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ ]
+ )
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ instance_image = self.image_transforms_resize_and_crop(instance_image)
+
+ example["PIL_images"] = instance_image
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ class_image = self.image_transforms_resize_and_crop(class_image)
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_PIL_images"] = class_image
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with="tensorboard",
+ project_config=project_config,
+ )
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(
+ sample_dataset, batch_size=args.sample_batch_size, num_workers=1
+ )
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+ transform_to_pil = transforms.ToPILImage()
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ bsz = len(example["prompt"])
+ fake_images = torch.rand((3, args.resolution, args.resolution))
+ transform_to_pil = transforms.ToPILImage()
+ fake_pil_images = transform_to_pil(fake_images)
+
+ fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)
+
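+ # Because the mask covers the whole image, the inpainting pipeline regenerates the entire
+ # canvas from the class prompt; the random placeholder image above mainly provides the
+ # expected input shape.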
+ images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ vae.requires_grad_(False)
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
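+ # The instance examples come first and the class examples second, so the model outputs can
+ # later be split back into the two groups with `torch.chunk(noise_pred, 2, dim=0)` when the
+ # prior preservation loss is computed.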
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ prior_pil = [example["class_PIL_images"] for example in examples]
+
+ masks = []
+ masked_images = []
+ for example in examples:
+ pil_image = example["PIL_images"]
+ # generate a random mask
+ mask = random_mask(pil_image.size, 1, False)
+ # prepare mask and masked image
+ mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
+
+ masks.append(mask)
+ masked_images.append(masked_image)
+
+ if args.with_prior_preservation:
+ for pil_image in prior_pil:
+ # generate a random mask
+ mask = random_mask(pil_image.size, 1, False)
+ # prepare mask and masked image
+ mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
+
+ masks.append(mask)
+ masked_images.append(masked_image)
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+ masks = torch.stack(masks)
+ masked_images = torch.stack(masked_images)
+ batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
+ return batch
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+ accelerator.register_for_checkpointing(lr_scheduler)
+
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move the text_encoder and vae to the GPU.
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision,
+ # as these models are only used for inference; keeping weights in full precision is not required.
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if not args.train_text_encoder:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Convert masked images to latent space
+ masked_latents = vae.encode(
+ batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
+ ).latent_dist.sample()
+ masked_latents = masked_latents * vae.config.scaling_factor
+
+ masks = batch["masks"]
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ mask = torch.stack(
+ [
+ torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
+ for mask in masks
+ ]
+ )
+ mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # concatenate the noised latents with the mask and the masked latents
+ latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
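+ # The concatenated input has 4 + 1 + 4 = 9 channels (noisy latents, downsampled mask,
+ # masked-image latents), which is the input layout the inpainting UNet expects
+ # (assuming the standard 4-channel Stable Diffusion VAE latents).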
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+ noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ accelerator.wait_for_everyone()
+
+ # Create the pipeline using the trained modules and save it.
+ if accelerator.is_main_process:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py b/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..720a69b4e7917d858a97c8855229efcf8a239c99
--- /dev/null
+++ b/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py
@@ -0,0 +1,831 @@
+import argparse
+import math
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image, ImageDraw
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.13.0.dev0")
+
+logger = get_logger(__name__)
+
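+# A minimal example invocation of this LoRA training script (illustrative only: the model id,
+# paths and hyperparameters below are placeholders, not values prescribed by this repository):
+#
+#   accelerate launch train_dreambooth_inpaint_lora.py \
+#     --pretrained_model_name_or_path="runwayml/stable-diffusion-inpainting" \
+#     --instance_data_dir="./instance_images" \
+#     --instance_prompt="a photo of sks dog" \
+#     --resolution=512 \
+#     --train_batch_size=1 \
+#     --gradient_accumulation_steps=1 \
+#     --learning_rate=1e-4 \
+#     --max_train_steps=500 \
+#     --output_dir="./dreambooth-inpaint-lora"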
+
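+# For a 512x512 RGB PIL image and mask, the function below returns a binary mask tensor of
+# shape (1, 1, 512, 512) with values in {0, 1} and a masked-image tensor of shape
+# (1, 3, 512, 512) scaled to [-1, 1], with the masked region zeroed out.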
+def prepare_mask_and_masked_image(image, mask):
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ return mask, masked_image
+
+
+# generate random masks
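+# With default arguments this returns a single-channel ("L") PIL mask containing one randomly
+# sized and positioned rectangle or ellipse filled with 255; with mask_full_image=True the
+# whole image is masked.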
+def random_mask(im_shape, ratio=1, mask_full_image=False):
+ mask = Image.new("L", im_shape, 0)
+ draw = ImageDraw.Draw(mask)
+ size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
+ # use this to always mask the whole image
+ if mask_full_image:
+ size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
+ limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
+ center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
+ draw_type = random.randint(0, 1)
+ if draw_type == 0 or mask_full_image:
+ draw.rectangle(
+ (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+ fill=255,
+ )
+ else:
+ draw.ellipse(
+ (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+ fill=255,
+ )
+
+ return mask
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
+ " sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="dreambooth-inpaint-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
+ " using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.instance_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms_resize_and_crop = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ ]
+ )
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ instance_image = self.image_transforms_resize_and_crop(instance_image)
+
+ example["PIL_images"] = instance_image
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ class_image = self.image_transforms_resize_and_crop(class_image)
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_PIL_images"] = class_image
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with="tensorboard",
+ project_config=accelerator_project_config,
+ )
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(
+ sample_dataset, batch_size=args.sample_batch_size, num_workers=1
+ )
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+ transform_to_pil = transforms.ToPILImage()
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ bsz = len(example["prompt"])
+ fake_images = torch.rand((3, args.resolution, args.resolution))
+ transform_to_pil = transforms.ToPILImage()
+ fake_pil_images = transform_to_pil(fake_images)
+
+ fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)
+
+ images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move the unet, vae and text_encoder to the GPU.
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision,
+ # as these models are only used for inference; keeping weights in full precision is not required.
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # now we will add new LoRA weights to the attention layers
+ # It's important to realize here how many attention weights will be added and of which sizes
+ # The sizes of the attention layers consist only of two different variables:
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
+
+ # Let's first see how many attention processors we will have to set.
+ # For Stable Diffusion, it should be equal to:
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
+ # => 32 layers
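+ # (The exact count can be checked at runtime with `len(unet.attn_processors)`.)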
+
+ # Set correct lora layers
+ lora_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+
+ unet.set_attn_processor(lora_attn_procs)
+ lora_layers = AttnProcsLayers(unet.attn_processors)
+
+ accelerator.register_for_checkpointing(lora_layers)
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ prior_pil = [example["class_PIL_images"] for example in examples]
+
+ masks = []
+ masked_images = []
+ for example in examples:
+ pil_image = example["PIL_images"]
+ # generate a random mask
+ mask = random_mask(pil_image.size, 1, False)
+ # prepare mask and masked image
+ mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
+
+ masks.append(mask)
+ masked_images.append(masked_image)
+
+ if args.with_prior_preservation:
+ for pil_image in prior_pil:
+ # generate a random mask
+ mask = random_mask(pil_image.size, 1, False)
+ # prepare mask and masked image
+ mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
+
+ masks.append(mask)
+ masked_images.append(masked_image)
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+ masks = torch.stack(masks)
+ masked_images = torch.stack(masked_images)
+ batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
+ return batch
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
+ )
+ # accelerator.register_for_checkpointing(lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth-inpaint-lora", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Convert masked images to latent space
+ masked_latents = vae.encode(
+ batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
+ ).latent_dist.sample()
+ masked_latents = masked_latents * vae.config.scaling_factor
+
+ masks = batch["masks"]
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ mask = torch.stack(
+ [
+ torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
+ for mask in masks
+ ]
+ ).to(dtype=weight_dtype)
+ mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # concatenate the noised latents with the mask and the masked latents
+ latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+ noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = lora_layers.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ accelerator.wait_for_everyone()
+
+ # Save the lora layers
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+ unet.save_attn_procs(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/geodiff/README.md b/diffusers/examples/research_projects/geodiff/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c85bbdf792cf7b7185f5df0b4ad24863e367eb3a
--- /dev/null
+++ b/diffusers/examples/research_projects/geodiff/README.md
@@ -0,0 +1,6 @@
+# GeoDiff
+
+> [!TIP]
+> This notebook is not actively maintained by the Diffusers team. For any questions or comments, please contact [natolambert](https://twitter.com/natolambert).
+
+This is an experimental research notebook demonstrating how to generate stable 3D structures of molecules with [GeoDiff](https://github.com/MinkaiXu/GeoDiff) and Diffusers.
diff --git a/diffusers/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/diffusers/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..bde093802a5d057570ac7f891bae4fbb9bb16602
--- /dev/null
+++ b/diffusers/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb
@@ -0,0 +1,3652 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "F88mignPnalS"
+ },
+ "source": [
+ "# Introduction\n",
+ "\n",
+ "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n",
+ "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n",
+ "\n",
+ "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n",
+ "\n",
+ "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n",
+ "\n",
+ "> Colab made by [natolambert](https://twitter.com/natolambert).\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7cnwXMocnuzB"
+ },
+ "source": [
+ "## Installations\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Install Conda"
+ ],
+ "metadata": {
+ "id": "ff9SxWnaNId9"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1g_6zOabItDk"
+ },
+ "source": [
+ "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "K0ofXobG5Y-X",
+ "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "nvcc: NVIDIA (R) Cuda compiler driver\n",
+ "Copyright (c) 2005-2021 NVIDIA Corporation\n",
+ "Built on Sun_Feb_14_21:12:58_PST_2021\n",
+ "Cuda compilation tools, release 11.2, V11.2.152\n",
+ "Build cuda_11.2.r11.2/compiler.29618528_0\n"
+ ]
+ }
+ ],
+ "source": [
+ "!nvcc --version"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VfthW90vI0nw"
+ },
+ "source": [
+ "Install Conda for some more complex dependencies for geometric networks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2WNFzSnbiE0k",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install -q condacolab"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NUsbWYCUI7Km"
+ },
+ "source": [
+ "Setup Conda"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "FZelreINdmd0",
+ "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "✨🍰✨ Everything looks OK!\n"
+ ]
+ }
+ ],
+ "source": [
+ "import condacolab\n",
+ "condacolab.install()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JzDHaPU7I9Sn"
+ },
+ "source": [
+ "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "JMxRjHhL7w8V",
+ "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n",
+ "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
+ "\n",
+ "## Package Plan ##\n",
+ "\n",
+ " environment location: /usr/local\n",
+ "\n",
+ " added / updated specs:\n",
+ " - cudatoolkit=11.1\n",
+ " - pytorch\n",
+ " - torchaudio\n",
+ " - torchvision\n",
+ "\n",
+ "\n",
+ "The following packages will be downloaded:\n",
+ "\n",
+ " package | build\n",
+ " ---------------------------|-----------------\n",
+ " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n",
+ " ------------------------------------------------------------\n",
+ " Total: 960 KB\n",
+ "\n",
+ "The following packages will be UPDATED:\n",
+ "\n",
+ " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n",
+ "\n",
+ "\n",
+ "\n",
+ "Downloading and Extracting Packages\n",
+ "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n",
+ "Preparing transaction: / \b\bdone\n",
+ "Verifying transaction: \\ \b\bdone\n",
+ "Executing transaction: / \b\bdone\n",
+ "Retrieving notices: ...working... done\n"
+ ]
+ }
+ ],
+ "source": [
+ "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n",
+ "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Need to remove a pathspec for colab that specifies the incorrect cuda version."
+ ],
+ "metadata": {
+ "id": "QDS6FPZ0Tu5b"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!rm /usr/local/conda-meta/pinned"
+ ],
+ "metadata": {
+ "id": "dq1lxR10TtrR",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Z1L3DdZOJB30"
+ },
+ "source": [
+ "Install torch geometric (used in the model later)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "D5ukfCOWfjzK",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
+ "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
+ "\n",
+ "## Package Plan ##\n",
+ "\n",
+ " environment location: /usr/local\n",
+ "\n",
+ " added / updated specs:\n",
+ " - pytorch-geometric=1.7.2\n",
+ "\n",
+ "\n",
+ "The following packages will be downloaded:\n",
+ "\n",
+ " package | build\n",
+ " ---------------------------|-----------------\n",
+ " decorator-4.4.2 | py_0 11 KB conda-forge\n",
+ " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n",
+ " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n",
+ " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n",
+ " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n",
+ " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n",
+ " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n",
+ " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n",
+ " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n",
+ " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n",
+ " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n",
+ " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n",
+ " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n",
+ " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n",
+ " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n",
+ " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n",
+ " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n",
+ " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n",
+ " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n",
+ " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n",
+ " ------------------------------------------------------------\n",
+ " Total: 55.9 MB\n",
+ "\n",
+ "The following NEW packages will be INSTALLED:\n",
+ "\n",
+ " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n",
+ " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n",
+ " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n",
+ " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n",
+ " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n",
+ " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n",
+ " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n",
+ " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n",
+ " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n",
+ " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n",
+ " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n",
+ " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n",
+ " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n",
+ " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n",
+ " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n",
+ "\n",
+ "The following packages will be DOWNGRADED:\n",
+ "\n",
+ " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n",
+ "\n",
+ "\n",
+ "\n",
+ "Downloading and Extracting Packages\n",
+ "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n",
+ "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n",
+ "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n",
+ "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n",
+ "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n",
+ "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n",
+ "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n",
+ "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n",
+ "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n",
+ "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n",
+ "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n",
+ "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n",
+ "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n",
+ "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n",
+ "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n",
+ "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n",
+ "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n",
+ "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n",
+ "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n",
+ "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n",
+ "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n",
+ "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
+ "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
+ "Retrieving notices: ...working... done\n"
+ ]
+ }
+ ],
+ "source": [
+ "!conda install -c rusty1s pytorch-geometric=1.7.2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ppxv6Mdkalbc"
+ },
+ "source": [
+ "### Install Diffusers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mgQA_XN-XGY2",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "85392615-b6a4-4052-9d2a-79604be62c94"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/content\n",
+ "Cloning into 'diffusers'...\n",
+ "remote: Enumerating objects: 9298, done.\u001b[K\n",
+ "remote: Counting objects: 100% (40/40), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (23/23), done.\u001b[K\n",
+ "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n",
+ "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n",
+ "Resolving deltas: 100% (6168/6168), done.\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "%cd /content\n",
+ "\n",
+ "# install latest HF diffusers (will update to the release once added)\n",
+ "!git clone https://github.com/huggingface/diffusers.git\n",
+ "!pip install -q /content/diffusers\n",
+ "\n",
+ "# dependencies for diffusers\n",
+ "!pip install -q datasets transformers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LZO6AJKuJKO8"
+ },
+ "source": [
+ "Check that torch is installed correctly and utilizing the GPU in the colab"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gZt7BNi1e1PA",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 53
+ },
+ "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "True\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "'1.8.2'"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ }
+ },
+ "metadata": {},
+ "execution_count": 8
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "print(torch.cuda.is_available())\n",
+ "torch.__version__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KLE7CqlfJNUO"
+ },
+ "source": [
+ "### Install Chemistry-specific Dependencies\n",
+ "\n",
+ "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0CPv_NvehRz3",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting rdkit\n",
+ " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n",
+ "Installing collected packages: rdkit\n",
+ "Successfully installed rdkit-2022.3.5\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install rdkit"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "88GaDbDPxJ5I"
+ },
+ "source": [
+ "### Get viewer from nglview\n",
+ "\n",
+ "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n",
+ "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n",
+ "The rdmol in this object is a source of ground truth for the generated molecules.\n",
+ "\n",
+ "You will use one rendering function from nglviewer later!\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jcl8GCS2mz6t",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting nglview\n",
+ " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n",
+ "Collecting jupyterlab-widgets\n",
+ " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ipywidgets>=7\n",
+ " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting widgetsnbextension~=4.0\n",
+ " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ipython>=6.1.0\n",
+ " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ipykernel>=4.5.1\n",
+ " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting traitlets>=4.3.1\n",
+ " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n",
+ "Collecting pyzmq>=17\n",
+ " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting matplotlib-inline>=0.1\n",
+ " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n",
+ "Collecting tornado>=6.1\n",
+ " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nest-asyncio\n",
+ " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n",
+ "Collecting debugpy>=1.0\n",
+ " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting psutil\n",
+ " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting jupyter-client>=6.1.12\n",
+ " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting pickleshare\n",
+ " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n",
+ "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n",
+ "Collecting backcall\n",
+ " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n",
+ "Collecting pexpect>4.3\n",
+ " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting pygments\n",
+ " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting jedi>=0.16\n",
+ " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n",
+ " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n",
+ "Collecting parso<0.9.0,>=0.8.0\n",
+ " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n",
+ "Collecting entrypoints\n",
+ " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n",
+ "Collecting jupyter-core>=4.9.2\n",
+ " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ptyprocess>=0.5\n",
+ " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n",
+ "Collecting wcwidth\n",
+ " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n",
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n",
+ "Building wheels for collected packages: nglview\n",
+ " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n",
+ " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n",
+ "Successfully built nglview\n",
+ "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n",
+ "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.colab-display-data+json": {
+ "pip_warning": {
+ "packages": [
+ "pexpect",
+ "pickleshare",
+ "wcwidth"
+ ]
+ }
+ }
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "!pip install nglview"
+ ]
+ },
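+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Purely as an illustration (this cell is not part of the original workflow), the snippet below shows the kind of one-line rendering call nglview provides for an RDKit molecule that has a 3D conformer. The example molecule (ethanol) and its embedded conformer are placeholders, not data used later in the notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative sketch only: build a tiny RDKit molecule with a 3D conformer and view it inline.\n",
+ "# The notebook later applies the same kind of nglview rendering to generated conformations.\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem import AllChem\n",
+ "import nglview as nv\n",
+ "\n",
+ "demo_mol = Chem.AddHs(Chem.MolFromSmiles(\"CCO\"))  # placeholder molecule (ethanol)\n",
+ "AllChem.EmbedMolecule(demo_mol, randomSeed=0)  # give it 3D coordinates\n",
+ "view = nv.show_rdkit(demo_mol)  # inline 3D viewer widget\n",
+ "view"
+ ]
+ },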
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Create a diffusion model"
+ ],
+ "metadata": {
+ "id": "8t8_e_uVLdKB"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Model class(es)"
+ ],
+ "metadata": {
+ "id": "G0rMncVtNSqU"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Imports"
+ ],
+ "metadata": {
+ "id": "L5FEXz5oXkzt"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n",
+ "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n",
+ "from dataclasses import dataclass\n",
+ "from typing import Callable, Tuple, Union\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from torch import Tensor, nn\n",
+ "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n",
+ "\n",
+ "from torch_geometric.nn import MessagePassing, radius, radius_graph\n",
+ "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n",
+ "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n",
+ "from torch_scatter import scatter_add\n",
+ "from torch_sparse import SparseTensor, coalesce\n",
+ "\n",
+ "from diffusers.configuration_utils import ConfigMixin, register_to_config\n",
+ "from diffusers.modeling_utils import ModelMixin\n",
+ "from diffusers.utils import BaseOutput\n"
+ ],
+ "metadata": {
+ "id": "-3-P4w5sXkRU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Helper classes"
+ ],
+ "metadata": {
+ "id": "EzJQXPN_XrMX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "@dataclass\n",
+ "class MoleculeGNNOutput(BaseOutput):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n",
+ " Hidden states output. Output of last layer of model.\n",
+ " \"\"\"\n",
+ "\n",
+ " sample: torch.Tensor\n",
+ "\n",
+ "\n",
+ "class MultiLayerPerceptron(nn.Module):\n",
+ " \"\"\"\n",
+ " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n",
+ " Args:\n",
+ " input_dim (int): input dimension\n",
+ " hidden_dim (list of int): hidden dimensions\n",
+ " activation (str or function, optional): activation function\n",
+ " dropout (float, optional): dropout rate\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n",
+ " super(MultiLayerPerceptron, self).__init__()\n",
+ "\n",
+ " self.dims = [input_dim] + hidden_dims\n",
+ " if isinstance(activation, str):\n",
+ " self.activation = getattr(F, activation)\n",
+ " else:\n",
+ " print(f\"Warning, activation passed {activation} is not string and ignored\")\n",
+ " self.activation = None\n",
+ " if dropout > 0:\n",
+ " self.dropout = nn.Dropout(dropout)\n",
+ " else:\n",
+ " self.dropout = None\n",
+ "\n",
+ " self.layers = nn.ModuleList()\n",
+ " for i in range(len(self.dims) - 1):\n",
+ " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n",
+ "\n",
+ " def forward(self, x):\n",
+ " \"\"\"\"\"\"\n",
+ " for i, layer in enumerate(self.layers):\n",
+ " x = layer(x)\n",
+ " if i < len(self.layers) - 1:\n",
+ " if self.activation:\n",
+ " x = self.activation(x)\n",
+ " if self.dropout:\n",
+ " x = self.dropout(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "class ShiftedSoftplus(torch.nn.Module):\n",
+ " def __init__(self):\n",
+ " super(ShiftedSoftplus, self).__init__()\n",
+ " self.shift = torch.log(torch.tensor(2.0)).item()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return F.softplus(x) - self.shift\n",
+ "\n",
+ "\n",
+ "class CFConv(MessagePassing):\n",
+ " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n",
+ " super(CFConv, self).__init__(aggr=\"add\")\n",
+ " self.lin1 = Linear(in_channels, num_filters, bias=False)\n",
+ " self.lin2 = Linear(num_filters, out_channels)\n",
+ " self.nn = mlp\n",
+ " self.cutoff = cutoff\n",
+ " self.smooth = smooth\n",
+ "\n",
+ " self.reset_parameters()\n",
+ "\n",
+ " def reset_parameters(self):\n",
+ " torch.nn.init.xavier_uniform_(self.lin1.weight)\n",
+ " torch.nn.init.xavier_uniform_(self.lin2.weight)\n",
+ " self.lin2.bias.data.fill_(0)\n",
+ "\n",
+ " def forward(self, x, edge_index, edge_length, edge_attr):\n",
+ " if self.smooth:\n",
+ " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n",
+ " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n",
+ " else:\n",
+ " C = (edge_length <= self.cutoff).float()\n",
+ " W = self.nn(edge_attr) * C.view(-1, 1)\n",
+ "\n",
+ " x = self.lin1(x)\n",
+ " x = self.propagate(edge_index, x=x, W=W)\n",
+ " x = self.lin2(x)\n",
+ " return x\n",
+ "\n",
+ " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n",
+ " return x_j * W\n",
+ "\n",
+ "\n",
+ "class InteractionBlock(torch.nn.Module):\n",
+ " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n",
+ " super(InteractionBlock, self).__init__()\n",
+ " mlp = Sequential(\n",
+ " Linear(num_gaussians, num_filters),\n",
+ " ShiftedSoftplus(),\n",
+ " Linear(num_filters, num_filters),\n",
+ " )\n",
+ " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n",
+ " self.act = ShiftedSoftplus()\n",
+ " self.lin = Linear(hidden_channels, hidden_channels)\n",
+ "\n",
+ " def forward(self, x, edge_index, edge_length, edge_attr):\n",
+ " x = self.conv(x, edge_index, edge_length, edge_attr)\n",
+ " x = self.act(x)\n",
+ " x = self.lin(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "class SchNetEncoder(Module):\n",
+ " def __init__(\n",
+ " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n",
+ " ):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.hidden_channels = hidden_channels\n",
+ " self.num_filters = num_filters\n",
+ " self.num_interactions = num_interactions\n",
+ " self.cutoff = cutoff\n",
+ "\n",
+ " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n",
+ "\n",
+ " self.interactions = ModuleList()\n",
+ " for _ in range(num_interactions):\n",
+ " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n",
+ " self.interactions.append(block)\n",
+ "\n",
+ " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n",
+ " if embed_node:\n",
+ " assert z.dim() == 1 and z.dtype == torch.long\n",
+ " h = self.embedding(z)\n",
+ " else:\n",
+ " h = z\n",
+ " for interaction in self.interactions:\n",
+ " h = h + interaction(h, edge_index, edge_length, edge_attr)\n",
+ "\n",
+ " return h\n",
+ "\n",
+ "\n",
+ "class GINEConv(MessagePassing):\n",
+ " \"\"\"\n",
+ " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n",
+ " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n",
+ " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n",
+ " self.nn = mlp\n",
+ " self.initial_eps = eps\n",
+ "\n",
+ " if isinstance(activation, str):\n",
+ " self.activation = getattr(F, activation)\n",
+ " else:\n",
+ " self.activation = None\n",
+ "\n",
+ " if train_eps:\n",
+ " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n",
+ " else:\n",
+ " self.register_buffer(\"eps\", torch.Tensor([eps]))\n",
+ "\n",
+ " def forward(\n",
+ " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n",
+ " ) -> torch.Tensor:\n",
+ " \"\"\"\"\"\"\n",
+ " if isinstance(x, torch.Tensor):\n",
+ " x: OptPairTensor = (x, x)\n",
+ "\n",
+ " # Node and edge feature dimensionalites need to match.\n",
+ " if isinstance(edge_index, torch.Tensor):\n",
+ " assert edge_attr is not None\n",
+ " assert x[0].size(-1) == edge_attr.size(-1)\n",
+ " elif isinstance(edge_index, SparseTensor):\n",
+ " assert x[0].size(-1) == edge_index.size(-1)\n",
+ "\n",
+ " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n",
+ " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n",
+ "\n",
+ " x_r = x[1]\n",
+ " if x_r is not None:\n",
+ " out += (1 + self.eps) * x_r\n",
+ "\n",
+ " return self.nn(out)\n",
+ "\n",
+ " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n",
+ " if self.activation:\n",
+ " return self.activation(x_j + edge_attr)\n",
+ " else:\n",
+ " return x_j + edge_attr\n",
+ "\n",
+ " def __repr__(self):\n",
+ " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n",
+ "\n",
+ "\n",
+ "class GINEncoder(torch.nn.Module):\n",
+ " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.hidden_dim = hidden_dim\n",
+ " self.num_convs = num_convs\n",
+ " self.short_cut = short_cut\n",
+ " self.concat_hidden = concat_hidden\n",
+ " self.node_emb = nn.Embedding(100, hidden_dim)\n",
+ "\n",
+ " if isinstance(activation, str):\n",
+ " self.activation = getattr(F, activation)\n",
+ " else:\n",
+ " self.activation = None\n",
+ "\n",
+ " self.convs = nn.ModuleList()\n",
+ " for i in range(self.num_convs):\n",
+ " self.convs.append(\n",
+ " GINEConv(\n",
+ " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n",
+ " activation=activation,\n",
+ " )\n",
+ " )\n",
+ "\n",
+ " def forward(self, z, edge_index, edge_attr):\n",
+ " \"\"\"\n",
+ " Input:\n",
+ " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n",
+ " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n",
+ " Output:\n",
+ " node_feature: graph feature\n",
+ " \"\"\"\n",
+ "\n",
+ " node_attr = self.node_emb(z) # (num_node, hidden)\n",
+ "\n",
+ " hiddens = []\n",
+ " conv_input = node_attr # (num_node, hidden)\n",
+ "\n",
+ " for conv_idx, conv in enumerate(self.convs):\n",
+ " hidden = conv(conv_input, edge_index, edge_attr)\n",
+ " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n",
+ " hidden = self.activation(hidden)\n",
+ " assert hidden.shape == conv_input.shape\n",
+ " if self.short_cut and hidden.shape == conv_input.shape:\n",
+ " hidden += conv_input\n",
+ "\n",
+ " hiddens.append(hidden)\n",
+ " conv_input = hidden\n",
+ "\n",
+ " if self.concat_hidden:\n",
+ " node_feature = torch.cat(hiddens, dim=-1)\n",
+ " else:\n",
+ " node_feature = hiddens[-1]\n",
+ "\n",
+ " return node_feature\n",
+ "\n",
+ "\n",
+ "class MLPEdgeEncoder(Module):\n",
+ " def __init__(self, hidden_dim=100, activation=\"relu\"):\n",
+ " super().__init__()\n",
+ " self.hidden_dim = hidden_dim\n",
+ " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n",
+ " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n",
+ "\n",
+ " @property\n",
+ " def out_channels(self):\n",
+ " return self.hidden_dim\n",
+ "\n",
+ " def forward(self, edge_length, edge_type):\n",
+ " \"\"\"\n",
+ " Input:\n",
+ " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n",
+ " Returns:\n",
+ " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n",
+ " \"\"\"\n",
+ " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n",
+ " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n",
+ " return d_emb * edge_attr # (num_edge, hidden)\n",
+ "\n",
+ "\n",
+ "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n",
+ " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n",
+ " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n",
+ " return h_pair\n",
+ "\n",
+ "\n",
+ "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " num_nodes: Number of atoms.\n",
+ " edge_index: Bond indices of the original graph.\n",
+ " edge_type: Bond types of the original graph.\n",
+ " order: Extension order.\n",
+ " Returns:\n",
+ " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n",
+ " \"\"\"\n",
+ "\n",
+ " def binarize(x):\n",
+ " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n",
+ "\n",
+ " def get_higher_order_adj_matrix(adj, order):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " adj: (N, N)\n",
+ " type_mat: (N, N)\n",
+ " Returns:\n",
+ " Following attributes will be updated:\n",
+ " - edge_index\n",
+ " - edge_type\n",
+ " Following attributes will be added to the data object:\n",
+ " - bond_edge_index: Original edge_index.\n",
+ " \"\"\"\n",
+ " adj_mats = [\n",
+ " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n",
+ " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n",
+ " ]\n",
+ "\n",
+ " for i in range(2, order + 1):\n",
+ " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n",
+ " order_mat = torch.zeros_like(adj)\n",
+ "\n",
+ " for i in range(1, order + 1):\n",
+ " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n",
+ "\n",
+ " return order_mat\n",
+ "\n",
+ " num_types = 22\n",
+ " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n",
+ " # from rdkit.Chem.rdchem import BondType as BT\n",
+ " N = num_nodes\n",
+ " adj = to_dense_adj(edge_index).squeeze(0)\n",
+ " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n",
+ "\n",
+ " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n",
+ " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n",
+ " assert (type_mat * type_highorder == 0).all()\n",
+ " type_new = type_mat + type_highorder\n",
+ "\n",
+ " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n",
+ " _, edge_order = dense_to_sparse(adj_order)\n",
+ "\n",
+ " # data.bond_edge_index = data.edge_index # Save original edges\n",
+ " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n",
+ "\n",
+ " return new_edge_index, new_edge_type\n",
+ "\n",
+ "\n",
+ "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n",
+ " assert edge_type.dim() == 1\n",
+ " N = pos.size(0)\n",
+ "\n",
+ " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n",
+ "\n",
+ " if is_sidechain is None:\n",
+ " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n",
+ " else:\n",
+ " # fetch sidechain and its batch index\n",
+ " is_sidechain = is_sidechain.bool()\n",
+ " dummy_index = torch.arange(pos.size(0), device=pos.device)\n",
+ " sidechain_pos = pos[is_sidechain]\n",
+ " sidechain_index = dummy_index[is_sidechain]\n",
+ " sidechain_batch = batch[is_sidechain]\n",
+ "\n",
+ " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n",
+ " r_edge_index_x = assign_index[1]\n",
+ " r_edge_index_y = assign_index[0]\n",
+ " r_edge_index_y = sidechain_index[r_edge_index_y]\n",
+ "\n",
+ " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n",
+ " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n",
+ " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n",
+ " # delete self loop\n",
+ " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n",
+ "\n",
+ " rgraph_adj = torch.sparse.LongTensor(\n",
+ " rgraph_edge_index,\n",
+ " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n",
+ " torch.Size([N, N]),\n",
+ " )\n",
+ "\n",
+ " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n",
+ "\n",
+ " new_edge_index = composed_adj.indices()\n",
+ " new_edge_type = composed_adj.values().long()\n",
+ "\n",
+ " return new_edge_index, new_edge_type\n",
+ "\n",
+ "\n",
+ "def extend_graph_order_radius(\n",
+ " num_nodes,\n",
+ " pos,\n",
+ " edge_index,\n",
+ " edge_type,\n",
+ " batch,\n",
+ " order=3,\n",
+ " cutoff=10.0,\n",
+ " extend_order=True,\n",
+ " extend_radius=True,\n",
+ " is_sidechain=None,\n",
+ "):\n",
+ " if extend_order:\n",
+ " edge_index, edge_type = _extend_graph_order(\n",
+ " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n",
+ " )\n",
+ "\n",
+ " if extend_radius:\n",
+ " edge_index, edge_type = _extend_to_radius_graph(\n",
+ " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n",
+ " )\n",
+ "\n",
+ " return edge_index, edge_type\n",
+ "\n",
+ "\n",
+ "def get_distance(pos, edge_index):\n",
+ " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n",
+ "\n",
+ "\n",
+ "def graph_field_network(score_d, pos, edge_index, edge_length):\n",
+ " \"\"\"\n",
+ " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n",
+ " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n",
+ " \"\"\"\n",
+ " N = pos.size(0)\n",
+ " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n",
+ " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n",
+ " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n",
+ " ) # (N, 3)\n",
+ " return score_pos\n",
+ "\n",
+ "\n",
+ "def clip_norm(vec, limit, p=2):\n",
+ " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n",
+ " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n",
+ " return vec * denom\n",
+ "\n",
+ "\n",
+ "def is_local_edge(edge_type):\n",
+ " return edge_type > 0\n"
+ ],
+ "metadata": {
+ "id": "oR1Y56QiLY90"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Main model class!"
+ ],
+ "metadata": {
+ "id": "QWrHJFcYXyUB"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class MoleculeGNN(ModelMixin, ConfigMixin):\n",
+ " @register_to_config\n",
+ " def __init__(\n",
+ " self,\n",
+ " hidden_dim=128,\n",
+ " num_convs=6,\n",
+ " num_convs_local=4,\n",
+ " cutoff=10.0,\n",
+ " mlp_act=\"relu\",\n",
+ " edge_order=3,\n",
+ " edge_encoder=\"mlp\",\n",
+ " smooth_conv=True,\n",
+ " ):\n",
+ " super().__init__()\n",
+ " self.cutoff = cutoff\n",
+ " self.edge_encoder = edge_encoder\n",
+ " self.edge_order = edge_order\n",
+ "\n",
+ " \"\"\"\n",
+ " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n",
+ " in SchNetEncoder\n",
+ " \"\"\"\n",
+ " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
+ " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
+ "\n",
+ " \"\"\"\n",
+ " The graph neural network that extracts node-wise features.\n",
+ " \"\"\"\n",
+ " self.encoder_global = SchNetEncoder(\n",
+ " hidden_channels=hidden_dim,\n",
+ " num_filters=hidden_dim,\n",
+ " num_interactions=num_convs,\n",
+ " edge_channels=self.edge_encoder_global.out_channels,\n",
+ " cutoff=cutoff,\n",
+ " smooth=smooth_conv,\n",
+ " )\n",
+ " self.encoder_local = GINEncoder(\n",
+ " hidden_dim=hidden_dim,\n",
+ " num_convs=num_convs_local,\n",
+ " )\n",
+ "\n",
+ " \"\"\"\n",
+ " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n",
+ " gradients w.r.t. edge_length (out_dim = 1).\n",
+ " \"\"\"\n",
+ " self.grad_global_dist_mlp = MultiLayerPerceptron(\n",
+ " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
+ " )\n",
+ "\n",
+ " self.grad_local_dist_mlp = MultiLayerPerceptron(\n",
+ " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
+ " )\n",
+ "\n",
+ " \"\"\"\n",
+ " Incorporate parameters together\n",
+ " \"\"\"\n",
+ " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n",
+ " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n",
+ "\n",
+ " def _forward(\n",
+ " self,\n",
+ " atom_type,\n",
+ " pos,\n",
+ " bond_index,\n",
+ " bond_type,\n",
+ " batch,\n",
+ " time_step, # NOTE, model trained without timestep performed best\n",
+ " edge_index=None,\n",
+ " edge_type=None,\n",
+ " edge_length=None,\n",
+ " return_edges=False,\n",
+ " extend_order=True,\n",
+ " extend_radius=True,\n",
+ " is_sidechain=None,\n",
+ " ):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " atom_type: Types of atoms, (N, ).\n",
+ " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n",
+ " bond_type: Bond types, (E, ).\n",
+ " batch: Node index to graph index, (N, ).\n",
+ " \"\"\"\n",
+ " N = atom_type.size(0)\n",
+ " if edge_index is None or edge_type is None or edge_length is None:\n",
+ " edge_index, edge_type = extend_graph_order_radius(\n",
+ " num_nodes=N,\n",
+ " pos=pos,\n",
+ " edge_index=bond_index,\n",
+ " edge_type=bond_type,\n",
+ " batch=batch,\n",
+ " order=self.edge_order,\n",
+ " cutoff=self.cutoff,\n",
+ " extend_order=extend_order,\n",
+ " extend_radius=extend_radius,\n",
+ " is_sidechain=is_sidechain,\n",
+ " )\n",
+ " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n",
+ " local_edge_mask = is_local_edge(edge_type) # (E, )\n",
+ "\n",
+ " # with the parameterization of NCSNv2\n",
+ " # DDPM loss implicit handle the noise variance scale conditioning\n",
+ " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n",
+ "\n",
+ " # Encoding global\n",
+ " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
+ "\n",
+ " # Global\n",
+ " node_attr_global = self.encoder_global(\n",
+ " z=atom_type,\n",
+ " edge_index=edge_index,\n",
+ " edge_length=edge_length,\n",
+ " edge_attr=edge_attr_global,\n",
+ " )\n",
+ " # Assemble pairwise features\n",
+ " h_pair_global = assemble_atom_pair_feature(\n",
+ " node_attr=node_attr_global,\n",
+ " edge_index=edge_index,\n",
+ " edge_attr=edge_attr_global,\n",
+ " ) # (E_global, 2H)\n",
+ " # Invariant features of edges (radius graph, global)\n",
+ " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n",
+ "\n",
+ " # Encoding local\n",
+ " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
+ " # edge_attr += temb_edge\n",
+ "\n",
+ " # Local\n",
+ " node_attr_local = self.encoder_local(\n",
+ " z=atom_type,\n",
+ " edge_index=edge_index[:, local_edge_mask],\n",
+ " edge_attr=edge_attr_local[local_edge_mask],\n",
+ " )\n",
+ " # Assemble pairwise features\n",
+ " h_pair_local = assemble_atom_pair_feature(\n",
+ " node_attr=node_attr_local,\n",
+ " edge_index=edge_index[:, local_edge_mask],\n",
+ " edge_attr=edge_attr_local[local_edge_mask],\n",
+ " ) # (E_local, 2H)\n",
+ "\n",
+ " # Invariant features of edges (bond graph, local)\n",
+ " if isinstance(sigma_edge, torch.Tensor):\n",
+ " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n",
+ " 1.0 / sigma_edge[local_edge_mask]\n",
+ " ) # (E_local, 1)\n",
+ " else:\n",
+ " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n",
+ "\n",
+ " if return_edges:\n",
+ " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n",
+ " else:\n",
+ " return edge_inv_global, edge_inv_local\n",
+ "\n",
+ " def forward(\n",
+ " self,\n",
+ " sample,\n",
+ " timestep: Union[torch.Tensor, float, int],\n",
+ " return_dict: bool = True,\n",
+ " sigma=1.0,\n",
+ " global_start_sigma=0.5,\n",
+ " w_global=1.0,\n",
+ " extend_order=False,\n",
+ " extend_radius=True,\n",
+ " clip_local=None,\n",
+ " clip_global=1000.0,\n",
+ " ) -> Union[MoleculeGNNOutput, Tuple]:\n",
+ " r\"\"\"\n",
+ " Args:\n",
+ " sample: packed torch geometric object\n",
+ " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n",
+ " return_dict (`bool`, *optional*, defaults to `True`):\n",
+ " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
+ " Returns:\n",
+ " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n",
+ " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n",
+ " \"\"\"\n",
+ "\n",
+ " # unpack sample\n",
+ " atom_type = sample.atom_type\n",
+ " bond_index = sample.edge_index\n",
+ " bond_type = sample.edge_type\n",
+ " num_graphs = sample.num_graphs\n",
+ " pos = sample.pos\n",
+ "\n",
+ " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n",
+ "\n",
+ " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n",
+ " atom_type=atom_type,\n",
+ " pos=sample.pos,\n",
+ " bond_index=bond_index,\n",
+ " bond_type=bond_type,\n",
+ " batch=sample.batch,\n",
+ " time_step=timesteps,\n",
+ " return_edges=True,\n",
+ " extend_order=extend_order,\n",
+ " extend_radius=extend_radius,\n",
+ " ) # (E_global, 1), (E_local, 1)\n",
+ "\n",
+ " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n",
+ " node_eq_local = graph_field_network(\n",
+ " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n",
+ " )\n",
+ " if clip_local is not None:\n",
+ " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n",
+ "\n",
+ " # Global\n",
+ " if sigma < global_start_sigma:\n",
+ " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n",
+ " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n",
+ " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n",
+ " else:\n",
+ " node_eq_global = 0\n",
+ "\n",
+ " # Sum\n",
+ " eps_pos = node_eq_local + node_eq_global * w_global\n",
+ "\n",
+ " if not return_dict:\n",
+ " return (-eps_pos,)\n",
+ "\n",
+ " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
+ ],
+ "metadata": {
+ "id": "MCeZA1qQXzoK"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CCIrPYSJj9wd"
+ },
+ "source": [
+ "### Load pretrained model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YdrAr6Ch--Ab"
+ },
+ "source": [
+ "#### Load a model\n",
+ "The model used is a design an\n",
+ "equivariant convolutional layer, named graph field network (GFN).\n",
+ "\n",
+ "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DyCo0nsqjbml",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 172,
+ "referenced_widgets": [
+ "d90f304e9560472eacfbdd11e46765eb",
+ "1c6246f15b654f4daa11c9bcf997b78c",
+ "c2321b3bff6f490ca12040a20308f555",
+ "b7feb522161f4cf4b7cc7c1a078ff12d",
+ "e2d368556e494ae7ae4e2e992af2cd4f",
+ "bbef741e76ec41b7ab7187b487a383df",
+ "561f742d418d4721b0670cc8dd62e22c",
+ "872915dd1bb84f538c44e26badabafdd",
+ "d022575f1fa2446d891650897f187b4d",
+ "fdc393f3468c432aa0ada05e238a5436",
+ "2c9362906e4b40189f16d14aa9a348da",
+ "6010fc8daa7a44d5aec4b830ec2ebaa1",
+ "7e0bb1b8d65249d3974200686b193be2",
+ "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
+ "6526646be5ed415c84d1245b040e629b",
+ "24d31fc3576e43dd9f8301d2ef3a37ab",
+ "2918bfaadc8d4b1a9832522c40dfefb8",
+ "a4bfdca35cc54dae8812720f1b276a08",
+ "e4901541199b45c6a18824627692fc39",
+ "f915cf874246446595206221e900b2fe",
+ "a9e388f22a9742aaaf538e22575c9433",
+ "42f6c3db29d7484ba6b4f73590abd2f4"
+ ]
+ },
+ "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading: 0%| | 0.00/3.27M [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "d90f304e9560472eacfbdd11e46765eb"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading: 0%| | 0.00/401 [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "6010fc8daa7a44d5aec4b830ec2ebaa1"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "The config attributes {'type': 'diffusion', 'network': 'dualenc', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'num_diffusion_timesteps': 5000} were passed to MoleculeGNN, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
+ "Some weights of the model checkpoint at fusing/gfn-molecule-gen-drugs were not used when initializing MoleculeGNN: ['betas', 'alphas']\n",
+ "- This IS expected if you are initializing MoleculeGNN from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing MoleculeGNN from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "DEVICE = 'cuda'\n",
+ "model = MoleculeGNN.from_pretrained(\"fusing/gfn-molecule-gen-drugs\").to(DEVICE)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The warnings above are because the pre-trained model was uploaded before cleaning the code!"
+ ],
+ "metadata": {
+ "id": "HdclRaqoUWUD"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PlOkPySoJ1m9"
+ },
+ "source": [
+ "#### Create scheduler\n",
+ "Note, other schedulers are used in the paper for slightly improved performance over DDPM."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nNHnIk9CkAb2"
+ },
+ "outputs": [],
+ "source": [
+ "from diffusers import DDPMScheduler"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RnDJdDBztjFF"
+ },
+ "outputs": [],
+ "source": [
+ "num_timesteps = 1000\n",
+ "scheduler = DDPMScheduler(num_train_timesteps=num_timesteps,beta_schedule=\"sigmoid\",beta_start=1e-7, beta_end=2e-3, clip_sample=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1vh3fpSAflkL"
+ },
+ "source": [
+ "### Get a dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "B6qzaGjVKFVk"
+ },
+ "source": [
+ "Grab a google tool so we can upload our data directly. Note you need to download the data from ***this [file](https://huggingface.co/datasets/fusing/geodiff-example-data/blob/main/data/molecules.pkl)***\n",
+ "\n",
+ "(direct downloading from the hub does not yet work for this datatype)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jbLl3EJdgj3x"
+ },
+ "outputs": [],
+ "source": [
+ "# from google.colab import files"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "E591lVuTgxPE"
+ },
+ "outputs": [],
+ "source": [
+ "# uploaded = files.upload()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KUNxfK3ln98Q"
+ },
+ "source": [
+ "Load the dataset with torch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7L4iOShTpcQX",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "7f2dcd29-493e-44de-98d1-3ad50f109a4a"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--2022-10-12 18:32:19-- https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
+ "Resolving huggingface.co (huggingface.co)... 44.195.102.200, 52.5.54.249, 54.210.225.113, ...\n",
+ "Connecting to huggingface.co (huggingface.co)|44.195.102.200|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 127774 (125K) [application/octet-stream]\n",
+ "Saving to: ‘molecules.pkl’\n",
+ "\n",
+ "molecules.pkl 100%[===================>] 124.78K 180KB/s in 0.7s \n",
+ "\n",
+ "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import numpy as np\n",
+ "\n",
+ "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
+ "dataset = torch.load('/content/molecules.pkl')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QZcmy1EvKQRk"
+ },
+ "source": [
+ "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JVjz6iH_H6Eh",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 20
+ }
+ ],
+ "source": [
+ "dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run the diffusion process"
+ ],
+ "metadata": {
+ "id": "vHNiZAUxNgoy"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jZ1KZrxKqENg"
+ },
+ "source": [
+ "#### Helper Functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "s240tYueqKKf"
+ },
+ "outputs": [],
+ "source": [
+ "from torch_geometric.data import Data, Batch\n",
+ "from torch_scatter import scatter_add, scatter_mean\n",
+ "from tqdm import tqdm\n",
+ "import copy\n",
+ "import os\n",
+ "\n",
+ "def repeat_data(data: Data, num_repeat) -> Batch:\n",
+ " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n",
+ " return Batch.from_data_list(datas)\n",
+ "\n",
+ "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n",
+ " datas = batch.to_data_list()\n",
+ " new_data = []\n",
+ " for i in range(num_repeat):\n",
+ " new_data += copy.deepcopy(datas)\n",
+ " return Batch.from_data_list(new_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AMnQTk0eqT7Z"
+ },
+ "source": [
+ "#### Constants"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WYGkzqgzrHmF"
+ },
+ "outputs": [],
+ "source": [
+ "num_samples = 1 # solutions per molecule\n",
+ "num_molecules = 3\n",
+ "\n",
+ "DEVICE = 'cuda'\n",
+ "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n",
+ "# constants for inference\n",
+ "w_global = 0.5 #0,.3 for qm9\n",
+ "global_start_sigma = 0.5\n",
+ "eta = 1.0\n",
+ "clip_local = None\n",
+ "clip_pos = None\n",
+ "\n",
+ "# constands for data handling\n",
+ "save_traj = False\n",
+ "save_data = False\n",
+ "output_dir = '/content/'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-xD5bJ3SqM7t"
+ },
+ "source": [
+ "#### Generate samples!\n",
+ "Note that the 3d representation of a molecule is referred to as the **conformation**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "x9xuLUNg26z1",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " after removing the cwd from sys.path.\n",
+ "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "results = []\n",
+ "\n",
+ "# define sigmas\n",
+ "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n",
+ "sigmas = sigmas.to(DEVICE)\n",
+ "\n",
+ "for count, data in enumerate(tqdm(dataset)):\n",
+ " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n",
+ "\n",
+ " data_input = data.clone()\n",
+ " data_input['pos_ref'] = None\n",
+ " batch = repeat_data(data_input, num_samples).to(DEVICE)\n",
+ "\n",
+ " # initial configuration\n",
+ " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n",
+ "\n",
+ " # for logging animation of denoising\n",
+ " pos_traj = []\n",
+ " with torch.no_grad():\n",
+ "\n",
+ " # scale initial sample\n",
+ " pos = pos_init * sigmas[-1]\n",
+ " for t in scheduler.timesteps:\n",
+ " batch.pos = pos\n",
+ "\n",
+ " # generate geometry with model, then filter it\n",
+ " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n",
+ "\n",
+ " # Update\n",
+ " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n",
+ "\n",
+ " pos = reconstructed_pos\n",
+ "\n",
+ " if torch.isnan(pos).any():\n",
+ " print(\"NaN detected. Please restart.\")\n",
+ " raise FloatingPointError()\n",
+ "\n",
+ " # recenter graph of positions for next iteration\n",
+ " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n",
+ "\n",
+ " # optional clipping\n",
+ " if clip_pos is not None:\n",
+ " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n",
+ " pos_traj.append(pos.clone().cpu())\n",
+ "\n",
+ " pos_gen = pos.cpu()\n",
+ " if save_traj:\n",
+ " pos_gen_traj = pos_traj.cpu()\n",
+ " data.pos_gen = torch.stack(pos_gen_traj)\n",
+ " else:\n",
+ " data.pos_gen = pos_gen\n",
+ " results.append(data)\n",
+ "\n",
+ "\n",
+ "if save_data:\n",
+ " save_path = os.path.join(output_dir, 'samples_all.pkl')\n",
+ "\n",
+ " with open(save_path, 'wb') as f:\n",
+ " pickle.dump(results, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Render the results!"
+ ],
+ "metadata": {
+ "id": "fSApwSaZNndW"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "d47Zxo2OKdgZ"
+ },
+ "source": [
+ "This function allows us to render 3d in colab."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "e9Cd0kCAv9b8"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import output\n",
+ "output.enable_custom_widget_manager()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Helper functions"
+ ],
+ "metadata": {
+ "id": "RjaVuR15NqzF"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "28rBYa9NKhlz"
+ },
+ "source": [
+ "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LKdKdwxcyTQ6"
+ },
+ "outputs": [],
+ "source": [
+ "from copy import deepcopy\n",
+ "def set_rdmol_positions(rdkit_mol, pos):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
+ " pos: (N_atoms, 3)\n",
+ " \"\"\"\n",
+ " mol = deepcopy(rdkit_mol)\n",
+ " set_rdmol_positions_(mol, pos)\n",
+ " return mol\n",
+ "\n",
+ "def set_rdmol_positions_(mol, pos):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
+ " pos: (N_atoms, 3)\n",
+ " \"\"\"\n",
+ " for i in range(pos.shape[0]):\n",
+ " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n",
+ " return mol\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NuE10hcpKmzK"
+ },
+ "source": [
+ "Process the generated data to make it easy to view."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KieVE1vc0_Vs",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "collect 5 generated molecules in `mols`\n"
+ ]
+ }
+ ],
+ "source": [
+ "# the model can generate multiple conformations per 2d geometry\n",
+ "num_gen = results[0]['pos_gen'].shape[0]\n",
+ "\n",
+ "# init storage objects\n",
+ "mols_gen = []\n",
+ "mols_orig = []\n",
+ "for to_process in results:\n",
+ "\n",
+ " # store the reference 3d position\n",
+ " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
+ "\n",
+ " # store the generated 3d position\n",
+ " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
+ "\n",
+ " # copy data to new object\n",
+ " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n",
+ "\n",
+ " # append results\n",
+ " mols_gen.append(new_mol)\n",
+ " mols_orig.append(to_process.rdmol)\n",
+ "\n",
+ "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tin89JwMKp4v"
+ },
+ "source": [
+ "Import tools to visualize the 2d chemical diagram of the molecule."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yqV6gllSZn38"
+ },
+ "outputs": [],
+ "source": [
+ "from rdkit.Chem import AllChem\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n",
+ "from IPython.display import SVG, display"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TFNKmGddVoOk"
+ },
+ "source": [
+ "Select molecule to visualize"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KzuwLlrrVaGc"
+ },
+ "outputs": [],
+ "source": [
+ "idx = 0\n",
+ "assert idx < len(results), \"selected molecule that was not generated\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Viewing"
+ ],
+ "metadata": {
+ "id": "hkb8w0_SNtU8"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I3R4QBQeKttN"
+ },
+ "source": [
+ "This 2D rendering is the equivalent of the **input to the model**!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gkQRWjraaKex",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 321
+ },
+ "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n "
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n",
+ "molSize=(450,300)\n",
+ "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n",
+ "drawer.DrawMolecule(mc)\n",
+ "drawer.FinishDrawing()\n",
+ "svg = drawer.GetDrawingText()\n",
+ "display(SVG(svg.replace('svg:','')))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z4FDMYMxKw2I"
+ },
+ "source": [
+ "Generate the 3d molecule!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "aT1Bkb8YxJfV",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17,
+ "referenced_widgets": [
+ "695ab5bbf30a4ab19df1f9f33469f314",
+ "eac6a8dcdc9d4335a2e51031793ead29"
+ ]
+ },
+ "outputId": "b98870ae-049d-4386-b676-166e9526bda2"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "695ab5bbf30a4ab19df1f9f33469f314"
+ }
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
+ }
+ }
+ }
+ }
+ }
+ ],
+ "source": [
+ "from nglview import show_rdkit as show"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "pxtq8I-I18C-",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 337,
+ "referenced_widgets": [
+ "be446195da2b4ff2aec21ec5ff963a54",
+ "c6596896148b4a8a9c57963b67c7782f",
+ "2489b5e5648541fbbdceadb05632a050",
+ "01e0ba4e5da04914b4652b8d58565d7b",
+ "c30e6c2f3e2a44dbbb3d63bd519acaa4",
+ "f31c6e40e9b2466a9064a2669933ecd5",
+ "19308ccac642498ab8b58462e3f1b0bb",
+ "4a081cdc2ec3421ca79dd933b7e2b0c4",
+ "e5c0d75eb5e1447abd560c8f2c6017e1",
+ "5146907ef6764654ad7d598baebc8b58",
+ "144ec959b7604a2cabb5ca46ae5e5379",
+ "abce2a80e6304df3899109c6d6cac199",
+ "65195cb7a4134f4887e9dd19f3676462"
+ ]
+ },
+ "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "NGLWidget()"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "be446195da2b4ff2aec21ec5ff963a54"
+ }
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
+ }
+ }
+ }
+ }
+ }
+ ],
+ "source": [
+ "# new molecule\n",
+ "show(mols_gen[idx])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "KJr4h2mwXeTo"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "d90f304e9560472eacfbdd11e46765eb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c",
+ "IPY_MODEL_c2321b3bff6f490ca12040a20308f555",
+ "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d"
+ ],
+ "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f"
+ }
+ },
+ "1c6246f15b654f4daa11c9bcf997b78c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df",
+ "placeholder": "",
+ "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c",
+ "value": "Downloading: 100%"
+ }
+ },
+ "c2321b3bff6f490ca12040a20308f555": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd",
+ "max": 3271865,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d",
+ "value": 3271865
+ }
+ },
+ "b7feb522161f4cf4b7cc7c1a078ff12d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436",
+ "placeholder": "",
+ "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da",
+ "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]"
+ }
+ },
+ "e2d368556e494ae7ae4e2e992af2cd4f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bbef741e76ec41b7ab7187b487a383df": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "561f742d418d4721b0670cc8dd62e22c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "872915dd1bb84f538c44e26badabafdd": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d022575f1fa2446d891650897f187b4d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "fdc393f3468c432aa0ada05e238a5436": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2c9362906e4b40189f16d14aa9a348da": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "6010fc8daa7a44d5aec4b830ec2ebaa1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2",
+ "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
+ "IPY_MODEL_6526646be5ed415c84d1245b040e629b"
+ ],
+ "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab"
+ }
+ },
+ "7e0bb1b8d65249d3974200686b193be2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8",
+ "placeholder": "",
+ "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08",
+ "value": "Downloading: 100%"
+ }
+ },
+ "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39",
+ "max": 401,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_f915cf874246446595206221e900b2fe",
+ "value": 401
+ }
+ },
+ "6526646be5ed415c84d1245b040e629b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433",
+ "placeholder": "",
+ "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4",
+ "value": " 401/401 [00:00<00:00, 13.5kB/s]"
+ }
+ },
+ "24d31fc3576e43dd9f8301d2ef3a37ab": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2918bfaadc8d4b1a9832522c40dfefb8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a4bfdca35cc54dae8812720f1b276a08": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "e4901541199b45c6a18824627692fc39": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f915cf874246446595206221e900b2fe": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "a9e388f22a9742aaaf538e22575c9433": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "42f6c3db29d7484ba6b4f73590abd2f4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "695ab5bbf30a4ab19df1f9f33469f314": {
+ "model_module": "nglview-js-widgets",
+ "model_name": "ColormakerRegistryModel",
+ "model_module_version": "3.0.1",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "nglview-js-widgets",
+ "_model_module_version": "3.0.1",
+ "_model_name": "ColormakerRegistryModel",
+ "_msg_ar": [],
+ "_msg_q": [],
+ "_ready": false,
+ "_view_count": null,
+ "_view_module": "nglview-js-widgets",
+ "_view_module_version": "3.0.1",
+ "_view_name": "ColormakerRegistryView",
+ "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29"
+ }
+ },
+ "eac6a8dcdc9d4335a2e51031793ead29": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "be446195da2b4ff2aec21ec5ff963a54": {
+ "model_module": "nglview-js-widgets",
+ "model_name": "NGLModel",
+ "model_module_version": "3.0.1",
+ "state": {
+ "_camera_orientation": [
+ -15.519693580202304,
+ -14.065056548036177,
+ -23.53197484807691,
+ 0,
+ -23.357853515109753,
+ 20.94055073042662,
+ 2.888695042134944,
+ 0,
+ 14.352363398292777,
+ 18.870825741878015,
+ -20.744689572909344,
+ 0,
+ 0.2724999189376831,
+ 0.6940000057220459,
+ -0.3734999895095825,
+ 1
+ ],
+ "_camera_str": "orthographic",
+ "_dom_classes": [],
+ "_gui_theme": null,
+ "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050",
+ "_igui": null,
+ "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b",
+ "_model_module": "nglview-js-widgets",
+ "_model_module_version": "3.0.1",
+ "_model_name": "NGLModel",
+ "_ngl_color_dict": {},
+ "_ngl_coordinate_resource": {},
+ "_ngl_full_stage_parameters": {
+ "impostor": true,
+ "quality": "medium",
+ "workerDefault": true,
+ "sampleLevel": 0,
+ "backgroundColor": "white",
+ "rotateSpeed": 2,
+ "zoomSpeed": 1.2,
+ "panSpeed": 1,
+ "clipNear": 0,
+ "clipFar": 100,
+ "clipDist": 10,
+ "fogNear": 50,
+ "fogFar": 100,
+ "cameraFov": 40,
+ "cameraEyeSep": 0.3,
+ "cameraType": "perspective",
+ "lightColor": 14540253,
+ "lightIntensity": 1,
+ "ambientColor": 14540253,
+ "ambientIntensity": 0.2,
+ "hoverTimeout": 0,
+ "tooltip": true,
+ "mousePreset": "default"
+ },
+ "_ngl_msg_archive": [
+ {
+ "target": "Stage",
+ "type": "call_method",
+ "methodName": "loadFile",
+ "reconstruc_color_scheme": false,
+ "args": [
+ {
+ "type": "blob",
+ "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n",
+ "binary": false
+ }
+ ],
+ "kwargs": {
+ "defaultRepresentation": true,
+ "ext": "pdb"
+ }
+ }
+ ],
+ "_ngl_original_stage_parameters": {
+ "impostor": true,
+ "quality": "medium",
+ "workerDefault": true,
+ "sampleLevel": 0,
+ "backgroundColor": "white",
+ "rotateSpeed": 2,
+ "zoomSpeed": 1.2,
+ "panSpeed": 1,
+ "clipNear": 0,
+ "clipFar": 100,
+ "clipDist": 10,
+ "fogNear": 50,
+ "fogFar": 100,
+ "cameraFov": 40,
+ "cameraEyeSep": 0.3,
+ "cameraType": "perspective",
+ "lightColor": 14540253,
+ "lightIntensity": 1,
+ "ambientColor": 14540253,
+ "ambientIntensity": 0.2,
+ "hoverTimeout": 0,
+ "tooltip": true,
+ "mousePreset": "default"
+ },
+ "_ngl_repr_dict": {
+ "0": {
+ "0": {
+ "type": "ball+stick",
+ "params": {
+ "lazy": false,
+ "visible": true,
+ "quality": "high",
+ "sphereDetail": 2,
+ "radialSegments": 20,
+ "openEnded": true,
+ "disableImpostor": false,
+ "aspectRatio": 1.5,
+ "lineOnly": false,
+ "cylinderOnly": false,
+ "multipleBond": "off",
+ "bondScale": 0.3,
+ "bondSpacing": 0.75,
+ "linewidth": 2,
+ "radiusType": "size",
+ "radiusData": {},
+ "radiusSize": 0.15,
+ "radiusScale": 2,
+ "assembly": "default",
+ "defaultAssembly": "",
+ "clipNear": 0,
+ "clipRadius": 0,
+ "clipCenter": {
+ "x": 0,
+ "y": 0,
+ "z": 0
+ },
+ "flatShaded": false,
+ "opacity": 1,
+ "depthWrite": true,
+ "side": "double",
+ "wireframe": false,
+ "colorScheme": "element",
+ "colorScale": "",
+ "colorReverse": false,
+ "colorValue": 9474192,
+ "colorMode": "hcl",
+ "roughness": 0.4,
+ "metalness": 0,
+ "diffuse": 16777215,
+ "diffuseInterior": false,
+ "useInteriorColor": true,
+ "interiorColor": 2236962,
+ "interiorDarkening": 0,
+ "matrix": {
+ "elements": [
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1
+ ]
+ },
+ "disablePicking": false,
+ "sele": ""
+ }
+ }
+ },
+ "1": {
+ "0": {
+ "type": "ball+stick",
+ "params": {
+ "lazy": false,
+ "visible": true,
+ "quality": "high",
+ "sphereDetail": 2,
+ "radialSegments": 20,
+ "openEnded": true,
+ "disableImpostor": false,
+ "aspectRatio": 1.5,
+ "lineOnly": false,
+ "cylinderOnly": false,
+ "multipleBond": "off",
+ "bondScale": 0.3,
+ "bondSpacing": 0.75,
+ "linewidth": 2,
+ "radiusType": "size",
+ "radiusData": {},
+ "radiusSize": 0.15,
+ "radiusScale": 2,
+ "assembly": "default",
+ "defaultAssembly": "",
+ "clipNear": 0,
+ "clipRadius": 0,
+ "clipCenter": {
+ "x": 0,
+ "y": 0,
+ "z": 0
+ },
+ "flatShaded": false,
+ "opacity": 1,
+ "depthWrite": true,
+ "side": "double",
+ "wireframe": false,
+ "colorScheme": "element",
+ "colorScale": "",
+ "colorReverse": false,
+ "colorValue": 9474192,
+ "colorMode": "hcl",
+ "roughness": 0.4,
+ "metalness": 0,
+ "diffuse": 16777215,
+ "diffuseInterior": false,
+ "useInteriorColor": true,
+ "interiorColor": 2236962,
+ "interiorDarkening": 0,
+ "matrix": {
+ "elements": [
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1
+ ]
+ },
+ "disablePicking": false,
+ "sele": ""
+ }
+ }
+ }
+ },
+ "_ngl_serialize": false,
+ "_ngl_version": "",
+ "_ngl_view_id": [
+ "FB989FD1-5B9C-446B-8914-6B58AF85446D"
+ ],
+ "_player_dict": {},
+ "_scene_position": {},
+ "_scene_rotation": {},
+ "_synced_model_ids": [],
+ "_synced_repr_model_ids": [],
+ "_view_count": null,
+ "_view_height": "",
+ "_view_module": "nglview-js-widgets",
+ "_view_module_version": "3.0.1",
+ "_view_name": "NGLView",
+ "_view_width": "",
+ "background": "white",
+ "frame": 0,
+ "gui_style": null,
+ "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f",
+ "max_frame": 0,
+ "n_components": 2,
+ "picked": {}
+ }
+ },
+ "c6596896148b4a8a9c57963b67c7782f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2489b5e5648541fbbdceadb05632a050": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ButtonModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ButtonView",
+ "button_style": "",
+ "description": "",
+ "disabled": false,
+ "icon": "compress",
+ "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199",
+ "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462",
+ "tooltip": ""
+ }
+ },
+ "01e0ba4e5da04914b4652b8d58565d7b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1",
+ "IPY_MODEL_5146907ef6764654ad7d598baebc8b58"
+ ],
+ "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379"
+ }
+ },
+ "c30e6c2f3e2a44dbbb3d63bd519acaa4": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f31c6e40e9b2466a9064a2669933ecd5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "19308ccac642498ab8b58462e3f1b0bb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4a081cdc2ec3421ca79dd933b7e2b0c4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "SliderStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "SliderStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": "",
+ "handle_color": null
+ }
+ },
+ "e5c0d75eb5e1447abd560c8f2c6017e1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "PlayModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "PlayModel",
+ "_playing": false,
+ "_repeat": false,
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "PlayView",
+ "description": "",
+ "description_tooltip": null,
+ "disabled": false,
+ "interval": 100,
+ "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4",
+ "max": 0,
+ "min": 0,
+ "show_repeat": true,
+ "step": 1,
+ "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5",
+ "value": 0
+ }
+ },
+ "5146907ef6764654ad7d598baebc8b58": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "IntSliderModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "IntSliderModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "IntSliderView",
+ "continuous_update": true,
+ "description": "",
+ "description_tooltip": null,
+ "disabled": false,
+ "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb",
+ "max": 0,
+ "min": 0,
+ "orientation": "horizontal",
+ "readout": true,
+ "readout_format": "d",
+ "step": 1,
+ "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4",
+ "value": 0
+ }
+ },
+ "144ec959b7604a2cabb5ca46ae5e5379": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "abce2a80e6304df3899109c6d6cac199": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": "34px"
+ }
+ },
+ "65195cb7a4134f4887e9dd19f3676462": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ButtonStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "button_color": null,
+ "font_weight": ""
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/gligen/README.md b/diffusers/examples/research_projects/gligen/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa922617d984ff1f22c7d8d9f35f396adb85c33a
--- /dev/null
+++ b/diffusers/examples/research_projects/gligen/README.md
@@ -0,0 +1,156 @@
+# GLIGEN: Open-Set Grounded Text-to-Image Generation
+
+These scripts contain the code to prepare the grounding data and train the GLIGEN model on the COCO dataset.
+
+### Install the requirements
+
+```bash
+conda create -n diffusers python==3.10
+conda activate diffusers
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or, for a default accelerate configuration without answering questions about your environment:
+
+```bash
+accelerate config default
+```
+
+Or, if your environment doesn't support an interactive shell (e.g., a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+### Prepare the training data
+
+If you want to make your own grounding data, you need to install a few additional tools.
+
+I used [RAM](https://github.com/xinyu1205/recognize-anything) to tag
+images, [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO/issues?q=refer) to detect objects,
+and [BLIP2](https://huggingface.co/docs/transformers/en/model_doc/blip-2) to caption instances.
+
+Only RAM needs to be installed manually:
+
+```bash
+pip install git+https://github.com/xinyu1205/recognize-anything.git --no-deps
+```
+
+Download the pre-trained model:
+
+```bash
+huggingface-cli download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
+huggingface-cli download --resume-download IDEA-Research/grounding-dino-base
+huggingface-cli download --resume-download Salesforce/blip2-flan-t5-xxl
+huggingface-cli download --resume-download openai/clip-vit-large-patch14
+huggingface-cli download --resume-download masterful/gligen-1-4-generation-text-box
+```
+
+Make the training data on 8 GPUs:
+
+```bash
+torchrun --master_port 17673 --nproc_per_node=8 make_datasets.py \
+ --data_root /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
+ --save_root /root/gligen_data \
+ --ram_checkpoint /root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth
+```
+
+You can download the prepared COCO training data with:
+
+```bash
+huggingface-cli download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
+```
+
+It's in the following format:
+
+```json
+[
+    ...
+    {
+        'file_path': Path,
+        'annos': [
+            {
+                'caption': Instance Caption,
+                'bbox': bbox in xyxy,
+                'text_embeddings_before_projection': CLIP text embedding before linear projection
+            }
+        ]
+    }
+    ...
+]
+```
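+
+For reference, here is a minimal sketch of loading and inspecting this file (assuming the structure above):
+
+```python
+import torch
+
+# Load the list of per-image annotation dicts saved with torch.save
+data = torch.load("coco_train2017.pth", map_location="cpu")
+
+sample = data[0]
+print(sample["file_path"])  # image file name inside the COCO train2017 folder
+for anno in sample["annos"]:
+    print(anno["caption"], anno["bbox"])  # instance caption and its xyxy box
+    print(anno["text_embeddings_before_projection"].shape)  # CLIP text embedding, e.g. torch.Size([768])
+```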
+
+### Training commands
+
+The training script is heavily based
+on https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py
+
+```bash
+accelerate launch train_gligen_text.py \
+ --data_path /root/data/zhizhonghuang/coco_train2017.pth \
+ --image_path /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
+ --train_batch_size 8 \
+ --max_train_steps 100000 \
+ --checkpointing_steps 1000 \
+ --checkpoints_total_limit 10 \
+ --learning_rate 5e-5 \
+ --dataloader_num_workers 16 \
+ --mixed_precision fp16 \
+ --report_to wandb \
+ --tracker_project_name gligen \
+ --output_dir /root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO
+```
+
+I trained the model on 8 A100 GPUs for about 11 hours (each GPU needs at least 24GB of memory). The generated images start to
+follow the layout at around 50k iterations.
+
+Note that although the pre-trained GLIGEN model is loaded, the parameters of `fuser` and `position_net` are reset before training (see line 420 in `train_gligen_text.py`), as sketched below.
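+
+The relevant logic looks roughly like this (abridged from the training script; `unet` is the loaded `UNet2DConditionModel`):
+
+```python
+import torch.nn as nn
+
+# Re-initialize the GLIGEN-specific layers before training
+for name, module in unet.named_modules():
+    if ("fuser" in name) or ("position_net" in name):
+        if isinstance(module, (nn.Linear, nn.LayerNorm)):
+            module.reset_parameters()
+
+# Only these layers are optimized; the rest of the UNet stays frozen
+params_to_optimize = []
+for name, param in unet.named_parameters():
+    if ("fuser" in name) or ("position_net" in name):
+        param.requires_grad = True
+        params_to_optimize.append(param)
+```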
+
+The trained model can be downloaded with:
+
+```bash
+huggingface-cli download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
+```
+
+You can run `demo.ipynb` to visualize the generated images.
+
+Example prompts:
+
+```python
+prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'
+boxes = [[0.041015625, 0.548828125, 0.453125, 0.859375],
+ [0.525390625, 0.552734375, 0.93359375, 0.865234375],
+ [0.12890625, 0.015625, 0.412109375, 0.279296875],
+ [0.578125, 0.08203125, 0.857421875, 0.27734375]]
+gligen_phrases = ['a green car', 'a blue truck', 'a red air balloon', 'a bird']
+```
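+
+These can be passed to the pipeline roughly as follows (a sketch based on `demo.ipynb`; to use the retrained model, load the downloaded UNet with `UNet2DConditionModel.from_pretrained` and pass it to the pipeline instead of the default one):
+
+```python
+from diffusers import StableDiffusionGLIGENPipeline
+
+# Load the GLIGEN text-box pipeline and run grounded generation
+pipe = StableDiffusionGLIGENPipeline.from_pretrained(
+    "masterful/gligen-1-4-generation-text-box", safety_checker=None
+).to("cuda")
+
+images = pipe(
+    prompt=prompt,
+    gligen_phrases=gligen_phrases,
+    gligen_boxes=boxes,
+    gligen_scheduled_sampling_beta=1.0,
+    num_inference_steps=50,
+).images
+images[0].save("gligen_example.png")
+```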
+
+Example images:
+
+
+### Citation
+
+```
+@article{li2023gligen,
+ title={GLIGEN: Open-Set Grounded Text-to-Image Generation},
+ author={Li, Yuheng and Liu, Haotian and Wu, Qingyang and Mu, Fangzhou and Yang, Jianwei and Gao, Jianfeng and Li, Chunyuan and Lee, Yong Jae},
+ journal={CVPR},
+ year={2023}
+}
+```
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/gligen/dataset.py b/diffusers/examples/research_projects/gligen/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b851bc6ded7d068c2920a66b3f1fd0e6916541d
--- /dev/null
+++ b/diffusers/examples/research_projects/gligen/dataset.py
@@ -0,0 +1,110 @@
+import os
+import random
+
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+def recalculate_box_and_verify_if_valid(x, y, w, h, image_size, original_image_size, min_box_size):
+ scale = image_size / min(original_image_size)
+ crop_y = (original_image_size[1] * scale - image_size) // 2
+ crop_x = (original_image_size[0] * scale - image_size) // 2
+ x0 = max(x * scale - crop_x, 0)
+ y0 = max(y * scale - crop_y, 0)
+ x1 = min((x + w) * scale - crop_x, image_size)
+ y1 = min((y + h) * scale - crop_y, image_size)
+ if (x1 - x0) * (y1 - y0) / (image_size * image_size) < min_box_size:
+ return False, (None, None, None, None)
+ return True, (x0, y0, x1, y1)
+
+
+class COCODataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ data_path,
+ image_path,
+ image_size=512,
+ min_box_size=0.01,
+ max_boxes_per_data=8,
+ tokenizer=None,
+ ):
+ super().__init__()
+ self.min_box_size = min_box_size
+ self.max_boxes_per_data = max_boxes_per_data
+ self.image_size = image_size
+ self.image_path = image_path
+ self.tokenizer = tokenizer
+ self.transforms = transforms.Compose(
+ [
+ transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(image_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ self.data_list = torch.load(data_path, map_location="cpu")
+
+ def __getitem__(self, index):
+        if self.max_boxes_per_data > 99:
+            raise ValueError("Are you sure you want to set such a large number of boxes per image?")
+
+ out = {}
+
+ data = self.data_list[index]
+ image = Image.open(os.path.join(self.image_path, data["file_path"])).convert("RGB")
+ original_image_size = image.size
+ out["pixel_values"] = self.transforms(image)
+
+ annos = data["annos"]
+
+ areas, valid_annos = [], []
+ for anno in annos:
+ # x, y, w, h = anno['bbox']
+ x0, y0, x1, y1 = anno["bbox"]
+ x, y, w, h = x0, y0, x1 - x0, y1 - y0
+ valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(
+ x, y, w, h, self.image_size, original_image_size, self.min_box_size
+ )
+ if valid:
+ anno["bbox"] = [x0, y0, x1, y1]
+ areas.append((x1 - x0) * (y1 - y0))
+ valid_annos.append(anno)
+
+ # Sort according to area and choose the largest N objects
+ wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
+ wanted_idxs = wanted_idxs[: self.max_boxes_per_data]
+ valid_annos = [valid_annos[i] for i in wanted_idxs]
+
+ out["boxes"] = torch.zeros(self.max_boxes_per_data, 4)
+ out["masks"] = torch.zeros(self.max_boxes_per_data)
+ out["text_embeddings_before_projection"] = torch.zeros(self.max_boxes_per_data, 768)
+
+ for i, anno in enumerate(valid_annos):
+ out["boxes"][i] = torch.tensor(anno["bbox"]) / self.image_size
+ out["masks"][i] = 1
+ out["text_embeddings_before_projection"][i] = anno["text_embeddings_before_projection"]
+
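+        # With a small probability, drop all grounding boxes (zero the masks) so the model also sees unconditioned samples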
+ prob_drop_boxes = 0.1
+ if random.random() < prob_drop_boxes:
+ out["masks"][:] = 0
+
+ caption = random.choice(data["captions"])
+
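+        # With some probability, drop the caption as well (standard text-conditioning dropout)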
+ prob_drop_captions = 0.5
+ if random.random() < prob_drop_captions:
+ caption = ""
+ caption = self.tokenizer(
+ caption,
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ )
+ out["caption"] = caption
+
+ return out
+
+ def __len__(self):
+ return len(self.data_list)
diff --git a/diffusers/examples/research_projects/gligen/demo.ipynb b/diffusers/examples/research_projects/gligen/demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..571f1a0323a2bb99c662f5d6955836e04d91677b
--- /dev/null
+++ b/diffusers/examples/research_projects/gligen/demo.ipynb
@@ -0,0 +1,201 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/root/miniconda/envs/densecaption/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "import torch\n",
+ "from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import diffusers\n",
+ "from diffusers import (\n",
+ " AutoencoderKL,\n",
+ " DDPMScheduler,\n",
+ " UNet2DConditionModel,\n",
+ " UniPCMultistepScheduler,\n",
+ " EulerDiscreteScheduler,\n",
+ ")\n",
+ "from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
+ "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
+ "\n",
+ "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
+ "\n",
+ "tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
+ "noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
+ "text_encoder = CLIPTextModel.from_pretrained(\n",
+ " pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
+ ")\n",
+ "vae = AutoencoderKL.from_pretrained(\n",
+ " pretrained_model_name_or_path, subfolder=\"vae\"\n",
+ ")\n",
+ "# unet = UNet2DConditionModel.from_pretrained(\n",
+ "# pretrained_model_name_or_path, subfolder=\"unet\"\n",
+ "# )\n",
+ "\n",
+ "noise_scheduler = EulerDiscreteScheduler.from_config(noise_scheduler.config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "unet = UNet2DConditionModel.from_pretrained(\n",
+ " '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You have disabled the safety checker for by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipe = StableDiffusionGLIGENPipeline(\n",
+ " vae,\n",
+ " text_encoder,\n",
+ " tokenizer,\n",
+ " unet,\n",
+ " noise_scheduler,\n",
+ " safety_checker=None,\n",
+ " feature_extractor=None,\n",
+ ")\n",
+ "pipe = pipe.to(\"cuda\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
+ "# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
+ "\n",
+ "# prompt = 'A realistic top-down view of a wooden table with two apples on it'\n",
+ "# gen_boxes = [('a wooden table', [20, 148, 472, 216]), ('an apple', [150, 226, 100, 100]), ('an apple', [280, 226, 100, 100])]\n",
+ "\n",
+ "# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
+ "# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
+ "\n",
+ "prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
+ "gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "boxes = np.array([x[1] for x in gen_boxes])\n",
+ "boxes = boxes / 512\n",
+ "boxes[:, 2] = boxes[:, 0] + boxes[:, 2]\n",
+ "boxes[:, 3] = boxes[:, 1] + boxes[:, 3]\n",
+ "boxes = boxes.tolist()\n",
+ "gligen_phrases = [x[0] for x in gen_boxes]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/root/miniconda/envs/densecaption/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py:683: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.\n",
+ " num_channels_latents = self.unet.in_channels\n",
+ "/root/miniconda/envs/densecaption/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py:716: FutureWarning: Accessing config attribute `cross_attention_dim` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'cross_attention_dim' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.cross_attention_dim'.\n",
+ " max_objs, self.unet.cross_attention_dim, device=device, dtype=self.text_encoder.dtype\n",
+ "100%|██████████| 50/50 [01:21<00:00, 1.64s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "images = pipe(\n",
+ " prompt=prompt,\n",
+ " gligen_phrases=gligen_phrases,\n",
+ " gligen_boxes=boxes,\n",
+ " gligen_scheduled_sampling_beta=1.0,\n",
+ " output_type=\"pil\",\n",
+ " num_inference_steps=50,\n",
+ " negative_prompt=\"artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate\",\n",
+ " num_images_per_prompt=16,\n",
+ ").images"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "diffusers.utils.make_image_grid(images, 4, len(images)//4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "densecaption",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/diffusers/examples/research_projects/gligen/make_datasets.py b/diffusers/examples/research_projects/gligen/make_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..37950a7b73ca2229b8e130035380d0c6a1e69129
--- /dev/null
+++ b/diffusers/examples/research_projects/gligen/make_datasets.py
@@ -0,0 +1,119 @@
+import argparse
+import os
+import random
+
+import torch
+import torchvision
+import torchvision.transforms as TS
+from PIL import Image
+from ram import inference_ram
+from ram.models import ram
+from tqdm import tqdm
+from transformers import (
+ AutoModelForZeroShotObjectDetection,
+ AutoProcessor,
+ Blip2ForConditionalGeneration,
+ Blip2Processor,
+ CLIPTextModel,
+ CLIPTokenizer,
+)
+
+
+torch.autograd.set_grad_enabled(False)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("Caption Generation script", add_help=False)
+ parser.add_argument("--data_root", type=str, required=True, help="path to COCO")
+ parser.add_argument("--save_root", type=str, required=True, help="path to save")
+ parser.add_argument("--ram_checkpoint", type=str, required=True, help="path to save")
+ args = parser.parse_args()
+
+ # ram_checkpoint = '/root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth'
+ # data_root = '/mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017'
+ # save_root = '/root/gligen_data'
+ box_threshold = 0.25
+ text_threshold = 0.2
+
+ import torch.distributed as dist
+
+ dist.init_process_group(backend="nccl", init_method="env://")
+ local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
+ device = f"cuda:{local_rank}"
+ torch.cuda.set_device(local_rank)
+
+ ram_model = ram(pretrained=args.ram_checkpoint, image_size=384, vit="swin_l").cuda().eval()
+ ram_processor = TS.Compose(
+ [TS.Resize((384, 384)), TS.ToTensor(), TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
+ )
+
+ grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
+ grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
+ "IDEA-Research/grounding-dino-base"
+ ).cuda()
+
+ blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
+ blip2_model = Blip2ForConditionalGeneration.from_pretrained(
+ "Salesforce/blip2-flan-t5-xxl", torch_dtype=torch.float16
+ ).cuda()
+
+ clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").cuda()
+ clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+ image_paths = [os.path.join(args.data_root, x) for x in os.listdir(args.data_root)]
+ random.shuffle(image_paths)
+
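+    # Each process shuffles the image list independently and skips images that already have a
+    # saved .pth file, so the ranks divide the work without explicit sharding.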
+    for image_path in tqdm(image_paths):
+ pth_path = os.path.join(args.save_root, os.path.basename(image_path))
+ if os.path.exists(pth_path):
+ continue
+
+ sample = {"file_path": os.path.basename(image_path), "annos": []}
+
+ raw_image = Image.open(image_path).convert("RGB")
+
+ res = inference_ram(ram_processor(raw_image).unsqueeze(0).cuda(), ram_model)
+
+ text = res[0].replace(" |", ".")
+
+ inputs = grounding_dino_processor(images=raw_image, text=text, return_tensors="pt")
+ inputs = {k: v.cuda() for k, v in inputs.items()}
+ outputs = grounding_dino_model(**inputs)
+
+ results = grounding_dino_processor.post_process_grounded_object_detection(
+ outputs,
+ inputs["input_ids"],
+ box_threshold=box_threshold,
+ text_threshold=text_threshold,
+ target_sizes=[raw_image.size[::-1]],
+ )
+ boxes = results[0]["boxes"]
+ labels = results[0]["labels"]
+ scores = results[0]["scores"]
+ indices = torchvision.ops.nms(boxes, scores, 0.5)
+ boxes = boxes[indices]
+ category_names = [labels[i] for i in indices]
+
+ for i, bbox in enumerate(boxes):
+ bbox = bbox.tolist()
+ inputs = blip2_processor(images=raw_image.crop(bbox), return_tensors="pt")
+ inputs = {k: v.cuda().to(torch.float16) for k, v in inputs.items()}
+ outputs = blip2_model.generate(**inputs)
+ caption = blip2_processor.decode(outputs[0], skip_special_tokens=True)
+ inputs = clip_tokenizer(
+ caption,
+ padding="max_length",
+ max_length=clip_tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ inputs = {k: v.cuda() for k, v in inputs.items()}
+ text_embeddings_before_projection = clip_text_encoder(**inputs).pooler_output.squeeze(0)
+
+ sample["annos"].append(
+ {
+ "caption": caption,
+ "bbox": bbox,
+ "text_embeddings_before_projection": text_embeddings_before_projection,
+ }
+ )
+ torch.save(sample, pth_path)
diff --git a/diffusers/examples/research_projects/gligen/requirements.txt b/diffusers/examples/research_projects/gligen/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b0dc16b7578cf84956329750d1f2146f3655f6e7
--- /dev/null
+++ b/diffusers/examples/research_projects/gligen/requirements.txt
@@ -0,0 +1,11 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+diffusers
+scipy
+timm
+fairscale
+wandb
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/gligen/train_gligen_text.py b/diffusers/examples/research_projects/gligen/train_gligen_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..6645d19b73e6d761b9e752454380aa4bcfcd9ecd
--- /dev/null
+++ b/diffusers/examples/research_projects/gligen/train_gligen_text.py
@@ -0,0 +1,715 @@
+# from accelerate.utils import write_basic_config
+#
+# write_basic_config()
+import argparse
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+
+import accelerate
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from packaging import version
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ EulerDiscreteScheduler,
+ StableDiffusionGLIGENPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import is_wandb_available, make_image_grid
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ pass
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+# check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+@torch.no_grad()
+def log_validation(vae, text_encoder, tokenizer, unet, noise_scheduler, args, accelerator, step, weight_dtype):
+ if accelerator.is_main_process:
+ print("generate test images...")
+ unet = accelerator.unwrap_model(unet)
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ pipeline = StableDiffusionGLIGENPipeline(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ EulerDiscreteScheduler.from_config(noise_scheduler.config),
+ safety_checker=None,
+ feature_extractor=None,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=not accelerator.is_main_process)
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ prompt = "A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky"
+ boxes = [
+ [0.041015625, 0.548828125, 0.453125, 0.859375],
+ [0.525390625, 0.552734375, 0.93359375, 0.865234375],
+ [0.12890625, 0.015625, 0.412109375, 0.279296875],
+ [0.578125, 0.08203125, 0.857421875, 0.27734375],
+ ]
+ gligen_phrases = ["a green car", "a blue truck", "a red air balloon", "a bird"]
+ images = pipeline(
+ prompt=prompt,
+ gligen_phrases=gligen_phrases,
+ gligen_boxes=boxes,
+ gligen_scheduled_sampling_beta=1.0,
+ output_type="pil",
+ num_inference_steps=50,
+ negative_prompt="artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate",
+ num_images_per_prompt=4,
+ generator=generator,
+ ).images
+ os.makedirs(os.path.join(args.output_dir, "images"), exist_ok=True)
+ make_image_grid(images, 1, 4).save(
+ os.path.join(args.output_dir, "images", f"generated-images-{step:06d}-{accelerator.process_index:02d}.png")
+ )
+
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+
+def parse_args(input_args=None):
+    parser = argparse.ArgumentParser(description="Simple example of a GLIGEN training script.")
+ parser.add_argument(
+ "--data_path",
+ type=str,
+ default="coco_train2017.pth",
+ help="Path to training dataset.",
+ )
+ parser.add_argument(
+ "--image_path",
+ type=str,
+ default="coco_train2017.pth",
+ help="Path to training images.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="controlnet-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="train_controlnet",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # import correct text encoder class
+ # text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+ # Load scheduler and models
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ pretrained_model_name_or_path = "masterful/gligen-1-4-generation-text-box"
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
+ noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
+
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
+
+ # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "unet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = unet.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ # controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # if args.gradient_checkpointing:
+ # controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(f"Controlnet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ optimizer_class = torch.optim.AdamW
+ # Optimizer creation
+ for n, m in unet.named_modules():
+ if ("fuser" in n) or ("position_net" in n):
+ import torch.nn as nn
+
+ if isinstance(m, (nn.Linear, nn.LayerNorm)):
+ m.reset_parameters()
+ params_to_optimize = []
+ for n, p in unet.named_parameters():
+ if ("fuser" in n) or ("position_net" in n):
+ p.requires_grad = True
+ params_to_optimize.append(p)
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ from dataset import COCODataset
+
+ train_dataset = COCODataset(
+ data_path=args.data_path,
+ image_path=args.image_path,
+ tokenizer=tokenizer,
+ image_size=args.resolution,
+ max_boxes_per_data=30,
+ )
+
+ print("num samples: ", len(train_dataset))
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ # collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+ # unet.to(accelerator.device, dtype=weight_dtype)
+ unet.to(accelerator.device, dtype=torch.float32)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ # tracker_config.pop("validation_prompt")
+ # tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ # logger.info("***** Running training *****")
+ # logger.info(f" Num examples = {len(train_dataset)}")
+ # logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ # logger.info(f" Num Epochs = {args.num_train_epochs}")
+ # logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ # logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ # logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ # logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ noise_scheduler,
+ args,
+ accelerator,
+ global_step,
+ weight_dtype,
+ )
+
+ # image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ with torch.no_grad():
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(
+ batch["caption"]["input_ids"].squeeze(1),
+ # batch['caption']['attention_mask'].squeeze(1),
+ return_dict=False,
+ )[0]
+
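+                # GLIGEN grounding inputs: normalized xyxy boxes, per-box CLIP text embeddings
+                # (before the projection layer), and masks marking which box slots are valid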
+ cross_attention_kwargs = {}
+ cross_attention_kwargs["gligen"] = {
+ "boxes": batch["boxes"],
+ "positive_embeddings": batch["text_embeddings_before_projection"],
+ "masks": batch["masks"],
+ }
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step:06d}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ # if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ noise_scheduler,
+ args,
+ accelerator,
+ global_step,
+ weight_dtype,
+ )
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet.save_pretrained(args.output_dir)
+ #
+ # # Run a final round of validation.
+ # image_logs = None
+ # if args.validation_prompt is not None:
+ # image_logs = log_validation(
+ # vae=vae,
+ # text_encoder=text_encoder,
+ # tokenizer=tokenizer,
+ # unet=unet,
+ # controlnet=None,
+ # args=args,
+ # accelerator=accelerator,
+ # weight_dtype=weight_dtype,
+ # step=global_step,
+ # is_final_validation=True,
+ # )
+ #
+ # if args.push_to_hub:
+ # save_model_card(
+ # repo_id,
+ # image_logs=image_logs,
+ # base_model=args.pretrained_model_name_or_path,
+ # repo_folder=args.output_dir,
+ # )
+ # upload_folder(
+ # repo_id=repo_id,
+ # folder_path=args.output_dir,
+ # commit_message="End of training",
+ # ignore_patterns=["step_*", "epoch_*"],
+ # )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/instructpix2pix_lora/README.md b/diffusers/examples/research_projects/instructpix2pix_lora/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cfcd98926c074f9d1e32a73f81b0a46a6bf5c6f2
--- /dev/null
+++ b/diffusers/examples/research_projects/instructpix2pix_lora/README.md
@@ -0,0 +1,75 @@
+# InstructPix2Pix text-to-edit-image fine-tuning
+This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
+This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). The script adds support for training LoRA layers on the UNet model.
+
+## Training script example
+
+```bash
+export MODEL_ID="timbrooks/instruct-pix2pix"
+export DATASET_ID="instruction-tuning-sd/cartoonization"
+export OUTPUT_DIR="instructPix2Pix-cartoonization"
+
+accelerate launch finetune_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_ID \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 --random_flip \
+ --train_batch_size=2 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
+ --validation_prompt="Generate a cartoonized version of the natural image" \
+ --seed=42 \
+ --rank=4 \
+ --output_dir=$OUTPUT_DIR \
+ --report_to=wandb \
+ --push_to_hub
+```
+
+## Inference
+After training, the LoRA weights of the model are stored in ```$OUTPUT_DIR```. They can be loaded on top of the base pipeline for inference:
+
+```py
+from PIL import Image
+from diffusers import StableDiffusionInstructPix2PixPipeline
+
+# load the base model pipeline
+pipe_lora = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix")
+
+# Load LoRA weights from the provided path
+output_dir = "path/to/lora_weight_directory"
+pipe_lora.unet.load_attn_procs(output_dir)
+
+# Edit an input image with an instruction of the kind used during training
+input_image_path = "/path/to/input_image"
+edit_prompt = "your edit instruction"
+input_image = Image.open(input_image_path)
+edited_images = pipe_lora(num_images_per_prompt=1, prompt=edit_prompt, image=input_image, num_inference_steps=1000).images
+edited_images[0].show()
+```
+
+## Results
+
+Here is an example of using the script to train an InstructPix2Pix model. It was trained on a Google Colab T4 GPU with the following settings:
+
+```bash
+MODEL_ID="timbrooks/instruct-pix2pix"
+DATASET_ID="instruction-tuning-sd/cartoonization"
+TRAIN_EPOCHS=100
+```
+
+Below are a few examples showing the input image, the edit prompt, and the edited image (the model output).
+
+
+
+
+
+
+Here are some rough statistics from training the model with this script:
+
+
+
+
+
+## References
+
+* InstructPix2Pix - https://github.com/timothybrooks/instruct-pix2pix
+* Dataset and example training script - https://huggingface.co/blog/instruction-tuning-sd
+* For more information about the project - https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/diffusers/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..997d448fa281bc0c5ba807d8efd176344eb0a6c4
--- /dev/null
+++ b/diffusers/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -0,0 +1,1063 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
+
+import argparse
+import logging
+import math
+import os
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import PIL
+import requests
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
+from diffusers.models.lora import LoRALinearLayer
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.26.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"),
+}
+WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--original_image_column",
+ type=str,
+ default="input_image",
+ help="The column of the dataset containing the original image on which edits where made.",
+ )
+ parser.add_argument(
+ "--edited_image_column",
+ type=str,
+ default="edited_image",
+ help="The column of the dataset containing the edited image.",
+ )
+ parser.add_argument(
+ "--edit_prompt_column",
+ type=str,
+ default="edit_prompt",
+ help="The column of the dataset containing the edit instruction.",
+ )
+ parser.add_argument(
+ "--val_image_url",
+ type=str,
+ default=None,
+ help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="instruct-pix2pix-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=256,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--conditioning_dropout_prob",
+ type=float,
+ default=None,
+ help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def convert_to_np(image, resolution):
+ image = image.convert("RGB").resize((resolution, resolution))
+ return np.array(image).transpose(2, 0, 1)
+
+
+def download_image(url):
+ image = PIL.Image.open(requests.get(url, stream=True).raw)
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # Freeze vae, text_encoder and unet
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
+ unet_lora_parameters = []
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
+ # Parse the attention module.
+ attn_module = unet
+ for n in attn_processor_name.split(".")[:-1]:
+ attn_module = getattr(attn_module, n)
+
+ # Set the `lora_layer` attribute of the attention-related matrices.
+ attn_module.to_q.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_k.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+ )
+ )
+
+ attn_module.to_v.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_out[0].set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_out[0].in_features,
+ out_features=attn_module.to_out[0].out_features,
+ rank=args.rank,
+ )
+ )
+
+ # Accumulate the LoRA params to optimize.
+ unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+    # Optimize only the LoRA parameters (unet_lora_parameters).
+ optimizer = optimizer_cls(
+ unet_lora_parameters,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.original_image_column is None:
+ original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ original_image_column = args.original_image_column
+ if original_image_column not in column_names:
+ raise ValueError(
+            f"'--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edit_prompt_column is None:
+ edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ edit_prompt_column = args.edit_prompt_column
+ if edit_prompt_column not in column_names:
+ raise ValueError(
+            f"'--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edited_image_column is None:
+ edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
+ else:
+ edited_image_column = args.edited_image_column
+ if edited_image_column not in column_names:
+ raise ValueError(
+            f"'--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(captions):
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ ]
+ )
+
+ def preprocess_images(examples):
+ original_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
+ )
+ edited_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
+ )
+ # We need to ensure that the original and the edited images undergo the same
+ # augmentation transforms.
+ images = np.concatenate([original_images, edited_images])
+ images = torch.tensor(images)
+ images = 2 * (images / 255) - 1
+ return train_transforms(images)
+
+ def preprocess_train(examples):
+ # Preprocess images.
+ preprocessed_images = preprocess_images(examples)
+ # Since the original and edited images were concatenated before
+ # applying the transformations, we need to separate them and reshape
+ # them accordingly.
+ original_images, edited_images = preprocessed_images.chunk(2)
+ original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
+ edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
+
+ # Collate the preprocessed images into the `examples`.
+ examples["original_pixel_values"] = original_images
+ examples["edited_pixel_values"] = edited_images
+
+ # Preprocess the captions.
+ captions = list(examples[edit_prompt_column])
+ examples["input_ids"] = tokenize_captions(captions)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
+ original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
+ edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
+ edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {
+ "original_pixel_values": original_pixel_values,
+ "edited_pixel_values": edited_pixel_values,
+ "input_ids": input_ids,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+    # Move the text_encoder and vae to the accelerator device and cast them to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("instruct-pix2pix", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # We want to learn the denoising process w.r.t the edited images which
+ # are conditioned on the original image (which was edited) and the edit instruction.
+ # So, first, convert images to latent space.
+ latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning.
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the additional image embedding for conditioning.
+ # Instead of getting a diagonal Gaussian here, we simply take the mode.
+ original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()
+
+ # Conditioning dropout to support classifier-free guidance during inference. For more details
+ # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
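+                # With dropout probability p = args.conditioning_dropout_prob, the masks below null the text
+                # conditioning when random_p < 2p and zero the image conditioning when p <= random_p < 3p, so each
+                # conditioning is dropped with probability 2p and both are dropped together with probability p.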
+ if args.conditioning_dropout_prob is not None:
+ random_p = torch.rand(bsz, device=latents.device, generator=generator)
+ # Sample masks for the edit prompts.
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
+ # Final text conditioning.
+ null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0]
+ encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
+
+ # Sample masks for the original images.
+ image_mask_dtype = original_image_embeds.dtype
+ image_mask = 1 - (
+ (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
+ )
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
+ # Final image conditioning.
+ original_image_embeds = image_mask * original_image_embeds
+
+ # Concatenate the `original_image_embeds` with the `noisy_latents`.
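+                # The InstructPix2Pix UNet expects 8 input channels: 4 from the noisy latent plus 4 from the
+                # VAE-encoded original image that serves as the image conditioning.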
+ concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet_lora_parameters)
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if (
+ (args.val_image_url is not None)
+ and (args.validation_prompt is not None)
+ and (epoch % args.validation_epochs == 0)
+ ):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+                # The models need to be unwrapped for compatibility in distributed training mode.
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=accelerator.unwrap_model(vae),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ original_image = download_image(args.val_image_url)
+ edited_images = []
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ for _ in range(args.num_validation_images):
+ edited_images.append(
+ pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(
+ wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
+ )
+ tracker.log({"validation": wandb_table})
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=accelerator.unwrap_model(vae),
+ unet=unet,
+ revision=args.revision,
+ variant=args.variant,
+ )
+        # Store only the LoRA layers.
+ unet.save_attn_procs(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ if args.validation_prompt is not None:
+ edited_images = []
+ pipeline = pipeline.to(accelerator.device)
+ with torch.autocast(str(accelerator.device).replace(":0", "")):
+ for _ in range(args.num_validation_images):
+ edited_images.append(
+ pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(
+ wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
+ )
+ tracker.log({"test": wandb_table})
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/intel_opts/README.md b/diffusers/examples/research_projects/intel_opts/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b2f75b3b0e3c4e7c90f51e55980a5b6044073f37
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/README.md
@@ -0,0 +1,36 @@
+## Diffusers examples with Intel optimizations
+
+**This research project is not actively maintained by the diffusers team. For any questions or comments, please make sure to tag @hshen14.**
+
+This project provides diffusers examples with Intel optimizations, such as Bfloat16 for training/fine-tuning acceleration and 8-bit integer (INT8) for inference acceleration, on Intel platforms.
+
+## Accelerating the fine-tuning for textual inversion
+
+We accelerate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single-node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processors.
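+
+The core pattern behind these examples is sketched below. This is a minimal illustration rather than the example code itself: `ipex.optimize` prepares the model and optimizer for Bfloat16 training on CPU, and the forward pass runs under the CPU autocast context. The `model` and `optimizer` objects here are placeholders.
+
+```python
+import intel_extension_for_pytorch as ipex
+import torch
+
+# Placeholder model/optimizer standing in for the text encoder and AdamW optimizer used in textual inversion.
+model = torch.nn.Linear(16, 16)
+optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+
+# ipex.optimize returns a model/optimizer pair prepared for Bfloat16 training on CPU.
+model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
+
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+    loss = model(torch.randn(4, 16)).pow(2).mean()
+loss.backward()
+optimizer.step()
+optimizer.zero_grad()
+```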
+
+## Accelerating the inference for Stable Diffusion using Bfloat16
+
+We accelerate Stable Diffusion inference with Bfloat16 using Intel Extension for PyTorch. The [script](inference_bf16.py) is designed to support standard Stable Diffusion models.
+```bash
+pip install diffusers transformers accelerate scipy safetensors
+
+export KMP_BLOCKTIME=1
+export KMP_SETTINGS=1
+export KMP_AFFINITY=granularity=fine,compact,1,0
+
+# Intel OpenMP
+export OMP_NUM_THREADS=< Cores to use >
+export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libiomp5.so
+# Jemalloc is a recommended malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support.
+export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libjemalloc.so
+export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:9000000000"
+
+# Launch with the default DDIM scheduler
+numactl --membind <node N> -C <cpu list> python inference_bf16.py
+# Launch with DPMSolverMultistepScheduler
+numactl --membind <node N> -C <cpu list> python inference_bf16.py --dpm
+```
+
+## Accelerating the inference for Stable Diffusion using INT8
+
+Coming soon ...
diff --git a/diffusers/examples/research_projects/intel_opts/inference_bf16.py b/diffusers/examples/research_projects/intel_opts/inference_bf16.py
new file mode 100644
index 0000000000000000000000000000000000000000..96ec709f433cd13dad0b93d5368d61e169b9df28
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/inference_bf16.py
@@ -0,0 +1,56 @@
+import argparse
+
+import intel_extension_for_pytorch as ipex
+import torch
+
+from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
+
+
+parser = argparse.ArgumentParser("Stable Diffusion script with intel optimization", add_help=False)
+parser.add_argument("--dpm", action="store_true", help="Enable DPMSolver or not")
+parser.add_argument("--steps", default=None, type=int, help="Num inference steps")
+args = parser.parse_args()
+
+
+device = "cpu"
+prompt = "a lovely <dicoo> in a red dress and hat, in the snowy and bright night, with many bright buildings"
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id)
+if args.dpm:
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to(device)
+
+# to channels last
+pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
+pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
+pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)
+if pipe.requires_safety_checker:
+ pipe.safety_checker = pipe.safety_checker.to(memory_format=torch.channels_last)
+
+# optimize with ipex
+sample = torch.randn(2, 4, 64, 64)
+timestep = torch.rand(1) * 999
+encoder_hidden_status = torch.randn(2, 77, 768)
+input_example = (sample, timestep, encoder_hidden_status)
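+# `sample_input` gives ipex.optimize an example UNet input (latent sample, timestep, text hidden states) so it
+# can optimize against realistic shapes; IPEX versions that do not accept this argument fall back to the plain
+# optimization call in the except branch below.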
+try:
+ pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)
+except Exception:
+ pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True)
+pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True)
+pipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
+if pipe.requires_safety_checker:
+ pipe.safety_checker = ipex.optimize(pipe.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
+
+# compute
+seed = 666
+generator = torch.Generator(device).manual_seed(seed)
+generate_kwargs = {"generator": generator}
+if args.steps is not None:
+ generate_kwargs["num_inference_steps"] = args.steps
+
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ image = pipe(prompt, **generate_kwargs).images[0]
+
+# save image
+image.save("generated.png")
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion/README.md b/diffusers/examples/research_projects/intel_opts/textual_inversion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3339b8e2cb63674481baf23de1bbe26772120ba7
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion/README.md
@@ -0,0 +1,68 @@
+## Textual Inversion fine-tuning example
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Training with Intel Extension for PyTorch
+
+Intel Extension for PyTorch provides optimizations for faster training and inference on CPUs. You can use the training example `textual_inversion.py`. Follow the [instructions](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) to get the model and [dataset](https://huggingface.co/sd-concepts-library/dicoo2) before running the script.
+
+The example supports both single node and multi-node distributed training:
+
+### Single node training
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATA_DIR="path-to-dir-containing-dicoo-images"
+
+python textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+  --placeholder_token="<dicoo>" --initializer_token="toy" \
+ --seed=7 \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --max_train_steps=3000 \
+ --learning_rate=2.5e-03 --scale_lr \
+ --output_dir="textual_inversion_dicoo"
+```
+
+Note: Bfloat16 is available on Intel Xeon Scalable Processors such as Cooper Lake and Sapphire Rapids. You may not see a performance speedup without Bfloat16 support.
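+
+If you are unsure whether your machine has native Bfloat16 support, one quick check on Linux is to look for the `avx512_bf16` or `amx_bf16` CPU flags (the exact flag names depend on your kernel/`lscpu` reporting):
+
+```bash
+# Prints the Bfloat16 ISA flags if the CPU exposes them; empty output means no native Bfloat16 support.
+lscpu | grep -oE 'avx512_bf16|amx_bf16' | sort -u
+```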
+
+### Multi-node distributed training
+
+Before running the scripts, make sure the distributed training dependency (oneCCL Bindings for PyTorch) is installed:
+
+```bash
+python -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu
+```
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATA_DIR="path-to-dir-containing-dicoo-images"
+
+oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
+source $oneccl_bindings_for_pytorch_path/env/setvars.sh
+
+python -m intel_extension_for_pytorch.cpu.launch --distributed \
+ --hostfile hostfile --nnodes 2 --nproc_per_node 2 textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+  --placeholder_token="<dicoo>" --initializer_token="toy" \
+ --seed=7 \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --max_train_steps=750 \
+ --learning_rate=2.5e-03 --scale_lr \
+ --output_dir="textual_inversion_dicoo"
+```
+The above is a simple distributed training setup with 2 nodes and 2 processes per node. Add the correct hostnames or IP addresses to the "hostfile" (see the example below) and make sure the 2 nodes can reach each other. For more details, please refer to the [user guide](https://github.com/intel/torch-ccl).
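+
+The exact "hostfile" layout is not shown in this example; a common convention (assumed here) is one reachable hostname or IP address per line, for instance with placeholder addresses:
+
+```bash
+# Placeholder "hostfile" contents: one node per line.
+cat > hostfile << EOF
+192.168.20.1
+192.168.20.2
+EOF
+```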
+
+
+### Reference
+
+We published a [Medium blog post](https://medium.com/intel-analytics-software/personalized-stable-diffusion-with-few-shot-fine-tuning-on-a-single-cpu-f01a3316b13) on how to create your own Stable Diffusion model on CPUs using textual inversion. Try it out if you are interested.
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt b/diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..af7ed6b21f6fb4518930a37786199643b1c60ece
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.21.0
+ftfy
+tensorboard
+Jinja2
+intel_extension_for_pytorch>=1.13
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea4a0d255b680ffce9c143cacfbcb711e114fdd7
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
@@ -0,0 +1,646 @@
+import argparse
+import itertools
+import math
+import os
+import random
+from pathlib import Path
+
+import intel_extension_for_pytorch as ipex
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.13.0.dev0")
+
+
+logger = get_logger(__name__)
+
+
+def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
+ logger.info("Saving embeddings")
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+        help="Save learned_embeds.bin every X update steps.",
+ )
+ parser.add_argument(
+ "--only_save_embeds",
+ action="store_true",
+ default=False,
+ help="Save only the embeddings for the new concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+            "The resolution for input images. All the images in the train/validation dataset will be resized to this"
+            " resolution."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=True,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU."
+        ),
+    )
+    # `--report_to` is referenced in `main()` (e.g. `log_with=args.report_to`), so it must be defined here.
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def freeze_params(params):
+ for param in params:
+ param.requires_grad = False
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.revision,
+ )
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+
+ # Freeze vae and unet
+ freeze_params(vae.parameters())
+ freeze_params(unet.parameters())
+ # Freeze all parameters except for the token embeddings in text encoder
+ params_to_freeze = itertools.chain(
+ text_encoder.text_model.encoder.parameters(),
+ text_encoder.text_model.final_layer_norm.parameters(),
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
+ )
+ freeze_params(params_to_freeze)
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # Move vae and unet to device
+ vae.to(accelerator.device)
+ unet.to(accelerator.device)
+
+    # Keep vae and unet in eval mode as we don't train these
+ vae.eval()
+ unet.eval()
+
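+    # ipex.optimize applies operator/layout optimizations and Bfloat16 weight casting to the frozen,
+    # inference-only models.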
+ unet = ipex.optimize(unet, dtype=torch.bfloat16, inplace=True)
+ vae = ipex.optimize(vae, dtype=torch.bfloat16, inplace=True)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+ global_step = 0
+
+ text_encoder.train()
+ text_encoder, optimizer = ipex.optimize(text_encoder, optimizer=optimizer, dtype=torch.bfloat16)
+
+ for epoch in range(args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ with accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(latents.shape).to(latents.device)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ ).long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
+ accelerator.backward(loss)
+
+ # Zero out the gradients for all token embeddings except the newly added
+ # embeddings for the concept, as we only want to optimize the concept embeddings
+ if accelerator.num_processes > 1:
+ grads = text_encoder.module.get_input_embeddings().weight.grad
+ else:
+ grads = text_encoder.get_input_embeddings().weight.grad
+ # Get the index for tokens that we want to zero the grads for
+ index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
+ grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ accelerator.wait_for_everyone()
+
+ # Create the pipeline using the trained modules and save it.
+ if accelerator.is_main_process:
+ if args.push_to_hub and args.only_save_embeds:
+ logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = not args.only_save_embeds
+ if save_full_model:
+ pipeline = StableDiffusionPipeline(
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4a227cdb4d63585cc0f0ab76424be8a0b2c5b604
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md
@@ -0,0 +1,93 @@
+# Distillation for quantization on Textual Inversion models to personalize text2image
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like Stable Diffusion on your own images. _By using just 3-5 images, new concepts can be taught to Stable Diffusion and the model personalized on your own images._
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+We have enabled distillation for quantization in `textual_inversion.py` to do quantization aware training as well as distillation on the model generated by Textual Inversion method.
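+
+At its core, the script adds a new placeholder token to the tokenizer, initializes its embedding from the `initializer_token`, and then optimizes only that embedding while the UNet, VAE, and the rest of the text encoder stay frozen. A minimal sketch of that setup (the `<my-concept>` token name is purely illustrative):
+
+```python
+from transformers import CLIPTextModel, CLIPTokenizer
+
+model_name = "CompVis/stable-diffusion-v1-4"
+tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
+
+# Register the new concept token and make room for its embedding
+tokenizer.add_tokens("<my-concept>")
+text_encoder.resize_token_embeddings(len(tokenizer))
+
+# Initialize the new embedding from a semantically related existing token, e.g. "toy"
+initializer_id = tokenizer.encode("toy", add_special_tokens=False)[0]
+placeholder_id = tokenizer.convert_tokens_to_ids("<my-concept>")
+embeddings = text_encoder.get_input_embeddings().weight.data
+embeddings[placeholder_id] = embeddings[initializer_id]
+```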
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+## Prepare Datasets
+
+One picture from the Hugging Face dataset [sd-concepts-library/dicoo2](https://huggingface.co/sd-concepts-library/dicoo2) is needed; save it to the `./dicoo` directory. The picture is shown below:
+
+
+
+
+
+## Get an FP32 Textual Inversion model
+
+Use the following command to fine-tune the Stable Diffusion model on the above dataset to obtain the FP32 Textual Inversion model.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATA_DIR="./dicoo"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="dicoo_model"
+```
+
+## Do distillation for quantization
+
+Distillation for quantization is a method that combines [intermediate layer knowledge distillation](https://github.com/intel/neural-compressor/blob/master/docs/source/distillation.md#intermediate-layer-knowledge-distillation) and [quantization aware training](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization.md#quantization-aware-training) in the same training process to improve the performance of the quantized model. Given an FP32 model, the distillation for quantization approach takes that model itself as the teacher and transfers the knowledge of the specified layers to the student model, i.e. the quantized version of the FP32 model, during the quantization aware training process.
+
+Once you have the FP32 Textual Inversion model, the following command takes it as input, performs distillation for quantization, and generates the INT8 Textual Inversion model.
+
+```bash
+export FP32_MODEL_NAME="./dicoo_model"
+export DATA_DIR="./dicoo"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$FP32_MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --use_ema --learnable_property="object" \
+ --placeholder_token="" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=300 \
+ --learning_rate=5.0e-04 --max_grad_norm=3 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="int8_model" \
+ --do_quantization --do_distillation --verify_loading
+```
+
+After the distillation for quantization process, the quantized UNet is about 4 times smaller (3279 MB -> 827 MB).
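+
+Under the hood, the script combines a `QuantizationAwareTrainingConfig` with a `DistillationConfig` whose criterion distills intermediate UNet layers from the frozen FP32 teacher, and hands both to Intel Neural Compressor's `prepare_compression`. A simplified sketch of that wiring (the `layer_mappings` below are only a tiny illustrative subset of the ones defined in `textual_inversion.py`, and `train_func` stands in for its training loop):
+
+```python
+import copy
+
+from diffusers import UNet2DConditionModel
+from neural_compressor import QuantizationAwareTrainingConfig
+from neural_compressor.config import DistillationConfig, IntermediateLayersKnowledgeDistillationLossConfig
+from neural_compressor.training import prepare_compression
+
+unet = UNet2DConditionModel.from_pretrained("./dicoo_model", subfolder="unet")
+teacher_unet = copy.deepcopy(unet)  # the FP32 model doubles as the distillation teacher
+
+layer_mappings = [[["conv_in"]], [["conv_out"]]]  # the script maps many more UNet blocks
+criterion = IntermediateLayersKnowledgeDistillationLossConfig(
+    layer_mappings=layer_mappings,
+    loss_types=["MSE"] * len(layer_mappings),
+    loss_weights=[1.0 / len(layer_mappings)] * len(layer_mappings),
+    add_origin_loss=True,
+)
+confs = [QuantizationAwareTrainingConfig(), DistillationConfig(teacher_model=teacher_unet, criterion=criterion)]
+
+def train_func(model):
+    # The training loop from textual_inversion.py goes here; each step passes its loss
+    # through compression_manager.callbacks.on_after_compute_loss(...).
+    pass
+
+compression_manager = prepare_compression(unet, confs)
+compression_manager.callbacks.on_train_begin()
+train_func(compression_manager.model)
+compression_manager.callbacks.on_train_end()
+compression_manager.model.save("./int8_model")  # writes best_model.pt next to the pipeline files
+```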
+
+## Inference
+
+Once you have trained an INT8 model with the above command, inference can be done simply using the `text2images.py` script. Make sure to include the `placeholder_token` in your prompt.
+
+```bash
+export INT8_MODEL_NAME="./int8_model"
+
+python text2images.py \
+ --pretrained_model_name_or_path=$INT8_MODEL_NAME \
+ --caption "a lovely in red dress and hat, in the snowly and brightly night, with many brighly buildings." \
+ --images_num 4
+```
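+
+The key step inside `text2images.py` is swapping the pipeline's FP32 UNet for the INT8 weights that Intel Neural Compressor saved as `best_model.pt`. A minimal sketch of that swap (the prompt is just an example):
+
+```python
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+from neural_compressor.utils.pytorch import load
+
+model_dir = "./int8_model"
+unet = UNet2DConditionModel.from_pretrained(model_dir, subfolder="unet")
+int8_unet = load(model_dir, model=unet)  # restores the INT8 weights from best_model.pt
+int8_unet.eval()
+
+pipeline = StableDiffusionPipeline.from_pretrained(model_dir, unet=unet)
+pipeline.unet = int8_unet  # the quantized UNet runs on CPU
+image = pipeline("a lovely <dicoo> in a red dress and hat").images[0]
+image.save("dicoo_int8.png")
+```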
+
+Here is a comparison of images generated by the FP32 model (left) and the INT8 model (right):
+
+
+
+
+
+
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cbd4c957be441a1aaf9a52e7ff02d772cb9d302b
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt
@@ -0,0 +1,7 @@
+accelerate
+torchvision
+transformers>=4.25.0
+ftfy
+tensorboard
+modelcards
+neural-compressor
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py
new file mode 100644
index 0000000000000000000000000000000000000000..a99d727712eb44b875576443837c81a442c72a6f
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py
@@ -0,0 +1,112 @@
+import argparse
+import math
+import os
+
+import torch
+from neural_compressor.utils.pytorch import load
+from PIL import Image
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-m",
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "-c",
+ "--caption",
+ type=str,
+ default="robotic cat with wings",
+ help="Text used to generate images.",
+ )
+ parser.add_argument(
+ "-n",
+ "--images_num",
+ type=int,
+ default=4,
+ help="How much images to generate.",
+ )
+ parser.add_argument(
+ "-s",
+ "--seed",
+ type=int,
+ default=42,
+ help="Seed for random process.",
+ )
+ parser.add_argument(
+ "-ci",
+ "--cuda_id",
+ type=int,
+ default=0,
+ help="cuda_id.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def image_grid(imgs, rows, cols):
+ if len(imgs) != rows * cols:
+ raise ValueError("The specified number of rows and columns does not match the number of images.")
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+ grid_w, grid_h = grid.size
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+def generate_images(
+ pipeline,
+ prompt="robotic cat with wings",
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ num_images_per_prompt=1,
+ seed=42,
+):
+ generator = torch.Generator(pipeline.device).manual_seed(seed)
+ images = pipeline(
+ prompt,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ num_images_per_prompt=num_images_per_prompt,
+ ).images
+ _rows = int(math.sqrt(num_images_per_prompt))
+ grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
+ return grid, images
+
+
+args = parse_args()
+# Load models and create wrapper for stable diffusion
+tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
+)
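+# Replace the safety checker with a no-op so generated images are returned unfiltered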
+pipeline.safety_checker = lambda images, clip_input: (images, False)
+if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
+ unet = load(args.pretrained_model_name_or_path, model=unet)
+ unet.eval()
+ setattr(pipeline, "unet", unet)
+else:
+ unet = unet.to(torch.device("cuda", args.cuda_id))
+pipeline = pipeline.to(unet.device)
+grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)
+grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
+dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
+os.makedirs(dirname, exist_ok=True)
+for idx, image in enumerate(images):
+ image.save(os.path.join(dirname, "{}.png".format(idx + 1)))
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..43667187596ef9b038d2ad03cbbdb02c0a6e40cf
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py
@@ -0,0 +1,996 @@
+import argparse
+import itertools
+import math
+import os
+import random
+from pathlib import Path
+from typing import Iterable
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from neural_compressor.utils import logger
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.utils import make_image_grid
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
+ logger.info("Saving embeddings")
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Example of distillation for quantization on Textual Inversion.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--do_quantization", action="store_true", help="Whether or not to do quantization.")
+ parser.add_argument("--do_distillation", action="store_true", help="Whether or not to do distillation.")
+ parser.add_argument(
+ "--verify_loading", action="store_true", help="Whether or not to verify the loading of the quantized model."
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
+class EMAModel:
+ """
+ Exponential Moving Average of models weights
+ """
+
+ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+ parameters = list(parameters)
+ self.shadow_params = [p.clone().detach() for p in parameters]
+
+ self.decay = decay
+ self.optimization_step = 0
+
+ def get_decay(self, optimization_step):
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ value = (1 + optimization_step) / (10 + optimization_step)
+ return 1 - min(self.decay, value)
+
+ @torch.no_grad()
+ def step(self, parameters):
+ parameters = list(parameters)
+
+ self.optimization_step += 1
+ self.decay = self.get_decay(self.optimization_step)
+
+ for s_param, param in zip(self.shadow_params, parameters):
+ if param.requires_grad:
+ tmp = self.decay * (s_param - param)
+ s_param.sub_(tmp)
+ else:
+ s_param.copy_(param)
+
+ torch.cuda.empty_cache()
+
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ """
+ Copy current averaged parameters into given collection of parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = list(parameters)
+ for s_param, param in zip(self.shadow_params, parameters):
+ param.data.copy_(s_param.data)
+
+ def to(self, device=None, dtype=None) -> None:
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+ Args:
+ device: like `device` argument to `torch.Tensor.to`
+ """
+ # .to() on the tensors handles None correctly
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+ for p in self.shadow_params
+ ]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ h, w = img.shape[0], img.shape[1]
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def freeze_params(params):
+ for param in params:
+ param.requires_grad = False
+
+
+def generate_images(pipeline, prompt="", guidance_scale=7.5, num_inference_steps=50, num_images_per_prompt=1, seed=42):
+ generator = torch.Generator(pipeline.device).manual_seed(seed)
+ images = pipeline(
+ prompt,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ num_images_per_prompt=num_images_per_prompt,
+ ).images
+ _rows = int(math.sqrt(num_images_per_prompt))
+ grid = make_image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
+ return grid
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with="tensorboard",
+ project_config=accelerator_project_config,
+ )
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer and add the placeholder token as an additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load models and create wrapper for stable diffusion
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.revision,
+ )
+
+ train_unet = False
+ # Freeze vae and unet
+ freeze_params(vae.parameters())
+ if not args.do_quantization and not args.do_distillation:
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+
+ freeze_params(unet.parameters())
+ # Freeze all parameters except for the token embeddings in text encoder
+ params_to_freeze = itertools.chain(
+ text_encoder.text_model.encoder.parameters(),
+ text_encoder.text_model.final_layer_norm.parameters(),
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
+ )
+ freeze_params(params_to_freeze)
+ else:
+ train_unet = True
+ freeze_params(text_encoder.parameters())
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ # only optimize the unet or embeddings of text_encoder
+ unet.parameters() if train_unet else text_encoder.get_input_embeddings().parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ if not train_unet:
+ text_encoder = accelerator.prepare(text_encoder)
+ unet.to(accelerator.device)
+ unet.eval()
+ else:
+ unet = accelerator.prepare(unet)
+ text_encoder.to(accelerator.device)
+ text_encoder.eval()
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+
+ # Move vae to device
+ vae.to(accelerator.device)
+
+ # Keep vae in eval mode as we don't train it
+ vae.eval()
+
+ compression_manager = None
+
+ def train_func(model):
+ if train_unet:
+ unet_ = model
+ text_encoder_ = text_encoder
+ else:
+ unet_ = unet
+ text_encoder_ = model
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+ global_step = 0
+
+ if train_unet and args.use_ema:
+ ema_unet = EMAModel(unet_.parameters())
+
+ for epoch in range(args.num_train_epochs):
+ model.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(model):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
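+ # 0.18215 is the Stable Diffusion v1 VAE latent scaling factor (i.e. vae.config.scaling_factor)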
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(latents.shape).to(latents.device)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ ).long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder_(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet_(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ loss = F.mse_loss(model_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+ if train_unet and compression_manager:
+ unet_inputs = {
+ "sample": noisy_latents,
+ "timestep": timesteps,
+ "encoder_hidden_states": encoder_hidden_states,
+ }
+ loss = compression_manager.callbacks.on_after_compute_loss(unet_inputs, model_pred, loss)
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+
+ if train_unet:
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet_.parameters(), args.max_grad_norm)
+ else:
+ # Zero out the gradients for all token embeddings except the newly added
+ # embeddings for the concept, as we only want to optimize the concept embeddings
+ if accelerator.num_processes > 1:
+ grads = text_encoder_.module.get_input_embeddings().weight.grad
+ else:
+ grads = text_encoder_.get_input_embeddings().weight.grad
+ # Get the index for tokens that we want to zero the grads for
+ index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
+ grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if train_unet and args.use_ema:
+ ema_unet.step(unet_.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+ if not train_unet and global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+ save_progress(text_encoder_, placeholder_token_id, accelerator, args, save_path)
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+ accelerator.wait_for_everyone()
+
+ if train_unet and args.use_ema:
+ ema_unet.copy_to(unet_.parameters())
+
+ if not train_unet:
+ return text_encoder_
+
+ if not train_unet:
+ text_encoder = train_func(text_encoder)
+ else:
+ import copy
+
+ model = copy.deepcopy(unet)
+ confs = []
+ if args.do_quantization:
+ from neural_compressor import QuantizationAwareTrainingConfig
+
+ q_conf = QuantizationAwareTrainingConfig()
+ confs.append(q_conf)
+
+ if args.do_distillation:
+ teacher_model = copy.deepcopy(model)
+
+ def attention_fetcher(x):
+ return x.sample
+
+ layer_mappings = [
+ [
+ [
+ "conv_in",
+ ]
+ ],
+ [
+ [
+ "time_embedding",
+ ]
+ ],
+ [["down_blocks.0.attentions.0", attention_fetcher]],
+ [["down_blocks.0.attentions.1", attention_fetcher]],
+ [
+ [
+ "down_blocks.0.resnets.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.0.resnets.1",
+ ]
+ ],
+ [
+ [
+ "down_blocks.0.downsamplers.0",
+ ]
+ ],
+ [["down_blocks.1.attentions.0", attention_fetcher]],
+ [["down_blocks.1.attentions.1", attention_fetcher]],
+ [
+ [
+ "down_blocks.1.resnets.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.1.resnets.1",
+ ]
+ ],
+ [
+ [
+ "down_blocks.1.downsamplers.0",
+ ]
+ ],
+ [["down_blocks.2.attentions.0", attention_fetcher]],
+ [["down_blocks.2.attentions.1", attention_fetcher]],
+ [
+ [
+ "down_blocks.2.resnets.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.2.resnets.1",
+ ]
+ ],
+ [
+ [
+ "down_blocks.2.downsamplers.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.3.resnets.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.3.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.0.resnets.0",
+ ]
+ ],
+ [
+ [
+ "up_blocks.0.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.0.resnets.2",
+ ]
+ ],
+ [
+ [
+ "up_blocks.0.upsamplers.0",
+ ]
+ ],
+ [["up_blocks.1.attentions.0", attention_fetcher]],
+ [["up_blocks.1.attentions.1", attention_fetcher]],
+ [["up_blocks.1.attentions.2", attention_fetcher]],
+ [
+ [
+ "up_blocks.1.resnets.0",
+ ]
+ ],
+ [
+ [
+ "up_blocks.1.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.1.resnets.2",
+ ]
+ ],
+ [
+ [
+ "up_blocks.1.upsamplers.0",
+ ]
+ ],
+ [["up_blocks.2.attentions.0", attention_fetcher]],
+ [["up_blocks.2.attentions.1", attention_fetcher]],
+ [["up_blocks.2.attentions.2", attention_fetcher]],
+ [
+ [
+ "up_blocks.2.resnets.0",
+ ]
+ ],
+ [
+ [
+ "up_blocks.2.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.2.resnets.2",
+ ]
+ ],
+ [
+ [
+ "up_blocks.2.upsamplers.0",
+ ]
+ ],
+ [["up_blocks.3.attentions.0", attention_fetcher]],
+ [["up_blocks.3.attentions.1", attention_fetcher]],
+ [["up_blocks.3.attentions.2", attention_fetcher]],
+ [
+ [
+ "up_blocks.3.resnets.0",
+ ]
+ ],
+ [
+ [
+ "up_blocks.3.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.3.resnets.2",
+ ]
+ ],
+ [["mid_block.attentions.0", attention_fetcher]],
+ [
+ [
+ "mid_block.resnets.0",
+ ]
+ ],
+ [
+ [
+ "mid_block.resnets.1",
+ ]
+ ],
+ [
+ [
+ "conv_out",
+ ]
+ ],
+ ]
+ layer_names = [layer_mapping[0][0] for layer_mapping in layer_mappings]
+ if not set(layer_names).issubset([n[0] for n in model.named_modules()]):
+ raise ValueError(
+ "Provided model is not compatible with the default layer_mappings, "
+ 'please use the model fine-tuned from "CompVis/stable-diffusion-v1-4", '
+ "or modify the layer_mappings variable to fit your model."
+ f"\nDefault layer_mappings are as such:\n{layer_mappings}"
+ )
+ from neural_compressor.config import DistillationConfig, IntermediateLayersKnowledgeDistillationLossConfig
+
+ distillation_criterion = IntermediateLayersKnowledgeDistillationLossConfig(
+ layer_mappings=layer_mappings,
+ loss_types=["MSE"] * len(layer_mappings),
+ loss_weights=[1.0 / len(layer_mappings)] * len(layer_mappings),
+ add_origin_loss=True,
+ )
+ d_conf = DistillationConfig(teacher_model=teacher_model, criterion=distillation_criterion)
+ confs.append(d_conf)
+
+ from neural_compressor.training import prepare_compression
+
+ compression_manager = prepare_compression(model, confs)
+ compression_manager.callbacks.on_train_begin()
+ model = compression_manager.model
+ train_func(model)
+ compression_manager.callbacks.on_train_end()
+
+ # Save the resulting model and its corresponding configuration in the given directory
+ model.save(args.output_dir)
+
+ logger.info(f"Optimized model saved to: {args.output_dir}.")
+
+ # change to framework model for further use
+ model = model.model
+
+ # Create the pipeline using the trained modules and save it.
+ templates = imagenet_style_templates_small if args.learnable_property == "style" else imagenet_templates_small
+ prompt = templates[0].format(args.placeholder_token)
+ if accelerator.is_main_process:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=accelerator.unwrap_model(unet),
+ tokenizer=tokenizer,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ pipeline = pipeline.to(unet.device)
+ baseline_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
+ baseline_model_images.save(
+ os.path.join(args.output_dir, "{}_baseline_model.png".format("_".join(prompt.split())))
+ )
+
+ if not train_unet:
+ # Also save the newly trained embeddings
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+ else:
+ setattr(pipeline, "unet", accelerator.unwrap_model(model))
+ if args.do_quantization:
+ pipeline = pipeline.to(torch.device("cpu"))
+
+ optimized_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
+ optimized_model_images.save(
+ os.path.join(args.output_dir, "{}_optimized_model.png".format("_".join(prompt.split())))
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+ if args.do_quantization and args.verify_loading:
+ # Load the model obtained after Intel Neural Compressor quantization
+ from neural_compressor.utils.pytorch import load
+
+ loaded_model = load(args.output_dir, model=unet)
+ loaded_model.eval()
+
+ setattr(pipeline, "unet", loaded_model)
+ if args.do_quantization:
+ pipeline = pipeline.to(torch.device("cpu"))
+
+ loaded_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
+ if loaded_model_images != optimized_model_images:
+ logger.info("The quantized model was not successfully loaded.")
+ else:
+ logger.info("The quantized model was successfully loaded.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/lora/README.md b/diffusers/examples/research_projects/lora/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..643f664ce1eba50a8871c9b9ba75f914acdd5f5f
--- /dev/null
+++ b/diffusers/examples/research_projects/lora/README.md
@@ -0,0 +1,83 @@
+# Stable Diffusion text-to-image fine-tuning
+This extended LoRA training script was authored by [haofanwang](https://github.com/haofanwang).
+This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py). It additionally supports adding LoRA layers to the text encoder.
+
+## Training with LoRA
+
+Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- The pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means the trained LoRA weights are easily portable.
+- LoRA attention layers let you control the extent to which the model is adapted toward the new training images via a `scale` parameter.
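+
+To make the rank-decomposition idea concrete, here is a small, self-contained sketch of a LoRA-style linear layer. It is only an illustration of the math; the training script itself relies on diffusers' LoRA attention processors / PEFT rather than this class:
+
+```python
+import torch
+import torch.nn as nn
+
+class LoRALinear(nn.Module):
+    """A frozen nn.Linear plus a trainable low-rank update: y = W x + (alpha / r) * B(A(x))."""
+
+    def __init__(self, base: nn.Linear, r: int = 4, alpha: int = 32):
+        super().__init__()
+        self.base = base
+        for p in self.base.parameters():  # the pretrained weights stay frozen
+            p.requires_grad_(False)
+        self.lora_a = nn.Linear(base.in_features, r, bias=False)   # down-projection A
+        self.lora_b = nn.Linear(r, base.out_features, bias=False)  # up-projection B
+        nn.init.normal_(self.lora_a.weight, std=1e-4)
+        nn.init.zeros_(self.lora_b.weight)  # B starts at zero, so the layer initially matches the original
+        self.scale = alpha / r
+
+    def forward(self, x):
+        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))
+
+layer = LoRALinear(nn.Linear(768, 768), r=4, alpha=32)
+trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
+print(trainable)  # only the small A/B matrices are trained
+```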
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption dataset
+using consumer GPUs such as the Tesla T4 or Tesla V100.
+
+### Training
+
+First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Narutos dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+```
+
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
+we need to be logged in and add the `--push_to_hub` flag.
+
+```bash
+huggingface-cli login
+```
+
+Now we can start training!
+
+```bash
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=512 --random_flip \
+ --train_batch_size=1 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --output_dir="sd-naruto-model-lora" \
+ --validation_prompt="cute dragon creature" --report_to="wandb"
+ --use_peft \
+ --lora_r=4 --lora_alpha=32 \
+ --lora_text_encoder_r=4 --lora_text_encoder_alpha=32
+```
+
+The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
+
+**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run `train_text_to_image_lora.py` in consumer GPUs like T4 or V100.___**
+
+The final LoRA embedding weights have been uploaded to [sayakpaul/sd-model-finetuned-lora-t4](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4). **___Note: [The final weights](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) are only 3 MB in size, which is orders of magnitudes smaller than the original model.___**
+
+You can check some inference samples that were logged during the course of the fine-tuning process [here](https://wandb.ai/sayakpaul/text2image-fine-tune/runs/q4lc0xsw).
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline` after loading the trained LoRA weights. You
+need to pass the `output_dir` when loading the LoRA weights, which in this case is `sd-naruto-model-lora`.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_path = "sayakpaul/sd-model-finetuned-lora-t4"
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
+pipe.unet.load_attn_procs(model_path)
+pipe.to("cuda")
+
+prompt = "A naruto with green eyes and red legs."
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("naruto.png")
+```
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/lora/requirements.txt b/diffusers/examples/research_projects/lora/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..89a1b73e70728b2395ca5f121f22def70f2076f9
--- /dev/null
+++ b/diffusers/examples/research_projects/lora/requirements.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+ftfy
+tensorboard
+Jinja2
+git+https://github.com/huggingface/peft.git
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/lora/train_text_to_image_lora.py b/diffusers/examples/research_projects/lora/train_text_to_image_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ebc1422b06424077dee45760a61a7727b182b3e
--- /dev/null
+++ b/diffusers/examples/research_projects/lora/train_text_to_image_lora.py
@@ -0,0 +1,1026 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
+
+import argparse
+import itertools
+import json
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.14.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = None, dataset_name: str = None, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- diffusers-training
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+
+ # lora args
+ parser.add_argument("--use_peft", action="store_true", help="Whether to use PEFT to add the LoRA adapters")
+ parser.add_argument("--lora_r", type=int, default=4, help="LoRA rank, only used if `use_peft` is True")
+ parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha, only used if `use_peft` is True")
+ parser.add_argument("--lora_dropout", type=float, default=0.0, help="LoRA dropout, only used if `use_peft` is True")
+ parser.add_argument(
+ "--lora_bias",
+ type=str,
+ default="none",
+ help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only', only used if `use_peft` is True",
+ )
+ parser.add_argument(
+ "--lora_text_encoder_r",
+ type=int,
+ default=4,
+ help="LoRA rank for the text encoder, only used if `use_peft` and `train_text_encoder` are True",
+ )
+ parser.add_argument(
+ "--lora_text_encoder_alpha",
+ type=int,
+ default=32,
+ help="LoRA alpha for the text encoder, only used if `use_peft` and `train_text_encoder` are True",
+ )
+ parser.add_argument(
+ "--lora_text_encoder_dropout",
+ type=float,
+ default=0.0,
+ help="LoRA dropout for the text encoder, only used if `use_peft` and `train_text_encoder` are True",
+ )
+ parser.add_argument(
+ "--lora_text_encoder_bias",
+ type=str,
+ default="none",
+ help="Bias type for the text encoder LoRA. Can be 'none', 'all' or 'lora_only', only used if `use_peft` and `train_text_encoder` are True",
+ )
+
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if args.use_peft:
+ from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict
+
+ UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+ config = LoraConfig(
+ r=args.lora_r,
+ lora_alpha=args.lora_alpha,
+ target_modules=UNET_TARGET_MODULES,
+ lora_dropout=args.lora_dropout,
+ bias=args.lora_bias,
+ )
+ unet = LoraModel(config, unet)
+
+ vae.requires_grad_(False)
+ if args.train_text_encoder:
+ config = LoraConfig(
+ r=args.lora_text_encoder_r,
+ lora_alpha=args.lora_text_encoder_alpha,
+ target_modules=TEXT_ENCODER_TARGET_MODULES,
+ lora_dropout=args.lora_text_encoder_dropout,
+ bias=args.lora_text_encoder_bias,
+ )
+ text_encoder = LoraModel(config, text_encoder)
+ else:
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ text_encoder.requires_grad_(False)
+
+ # now we will add new LoRA weights to the attention layers
+ # It's important to realize here how many attention weights will be added and of which sizes
+ # The sizes of the attention layers consist only of two different variables:
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
+
+ # Let's first see how many attention processors we will have to set.
+ # For Stable Diffusion, it should be equal to:
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
+ # => 32 layers
+
+ # Set correct lora layers
+ lora_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+
+ unet.set_attn_processor(lora_attn_procs)
+ lora_layers = AttnProcsLayers(unet.attn_processors)
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if not args.train_text_encoder:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ if args.use_peft:
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ optimizer = optimizer_cls(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+ else:
+ optimizer = optimizer_cls(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.use_peft:
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
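+ # Note: `resume_step` counts dataloader batches (not optimizer updates) within the
+ # resumed epoch, which is why the loop below skips that many batches before training continues.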
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
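+ # Dividing by `gradient_accumulation_steps` means the `train_loss` logged once the
+ # gradients are synchronized approximates the mean loss over one full optimizer update.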
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ if args.use_peft:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ else:
+ params_to_clip = lora_layers.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ if accelerator.is_main_process:
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.use_peft:
+ lora_config = {}
+ unwrapped_unet = accelerator.unwrap_model(unet)
+ state_dict = get_peft_model_state_dict(unwrapped_unet, state_dict=accelerator.get_state_dict(unet))
+ lora_config["peft_config"] = unwrapped_unet.get_peft_config_as_dict(inference=True)
+ if args.train_text_encoder:
+ unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
+ text_encoder_state_dict = get_peft_model_state_dict(
+ unwrapped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+ )
+ text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+ state_dict.update(text_encoder_state_dict)
+ lora_config["text_encoder_peft_config"] = unwrapped_text_encoder.get_peft_config_as_dict(
+ inference=True
+ )
+
+ accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt"))
+ with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f:
+ json.dump(lora_config, f)
+ else:
+ unet = unet.to(torch.float32)
+ unet.save_attn_procs(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+ )
+
+ if args.use_peft:
+
+ def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):
+ with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f:
+ lora_config = json.load(f)
+ print(lora_config)
+
+ checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt")
+ lora_checkpoint_sd = torch.load(checkpoint)
+ unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
+ text_encoder_lora_ds = {
+ k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
+ }
+
+ unet_config = LoraConfig(**lora_config["peft_config"])
+ pipe.unet = LoraModel(unet_config, pipe.unet)
+ set_peft_model_state_dict(pipe.unet, unet_lora_ds)
+
+ if "text_encoder_peft_config" in lora_config:
+ text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
+ pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
+ set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)
+
+ if dtype in (torch.float16, torch.bfloat16):
+ pipe.unet.half()
+ pipe.text_encoder.half()
+
+ pipe.to(device)
+ return pipe
+
+ pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)
+
+ else:
+ pipeline = pipeline.to(accelerator.device)
+ # load attention processors
+ pipeline.unet.load_attn_procs(args.output_dir)
+
+ # run inference
+ if args.seed is not None:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ else:
+ generator = None
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+ if accelerator.is_main_process:
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/multi_subject_dreambooth/README.md b/diffusers/examples/research_projects/multi_subject_dreambooth/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7c2e6f400935613391863b2cc73544baea1c144e
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_subject_dreambooth/README.md
@@ -0,0 +1,338 @@
+# Multi Subject DreamBooth training
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like Stable Diffusion given just a few (3~5) images of a subject.
+This `train_multi_subject_dreambooth.py` script shows how to implement the training procedure for one or more subjects and adapt it for Stable Diffusion. Note that this code is based on the `examples/dreambooth/train_dreambooth.py` script as of 01/06/2022.
+
+This script was added by @kopsahlong, and is not actively maintained. However, if you come across anything that could use fixing, feel free to open an issue and tag @kopsahlong.
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the script, make sure to install the library's training dependencies:
+
+To start, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the folder `diffusers/examples/research_projects/multi_subject_dreambooth` and run the following:
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+### Multi Subject Training Example
+To have your model learn multiple concepts at once, simply add the additional data directories and prompts to `instance_data_dir` and `instance_prompt` (as well as `class_data_dir` and `class_prompt` if `--with_prior_preservation` is specified), as one comma-separated string per argument.
+
+See an example with 2 subjects below, which learns a model for one dog subject and one human subject:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+
+# Subject 1
+export INSTANCE_DIR_1="path-to-instance-images-concept-1"
+export INSTANCE_PROMPT_1="a photo of a sks dog"
+export CLASS_DIR_1="path-to-class-images-dog"
+export CLASS_PROMPT_1="a photo of a dog"
+
+# Subject 2
+export INSTANCE_DIR_2="path-to-instance-images-concept-2"
+export INSTANCE_PROMPT_2="a photo of a t@y person"
+export CLASS_DIR_2="path-to-class-images-person"
+export CLASS_PROMPT_2="a photo of a person"
+
+accelerate launch train_multi_subject_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir="$INSTANCE_DIR_1,$INSTANCE_DIR_2" \
+ --output_dir=$OUTPUT_DIR \
+ --train_text_encoder \
+ --instance_prompt="$INSTANCE_PROMPT_1,$INSTANCE_PROMPT_2" \
+ --with_prior_preservation \
+ --prior_loss_weight=1.0 \
+ --class_data_dir="$CLASS_DIR_1,$CLASS_DIR_2" \
+ --class_prompt="$CLASS_PROMPT_1,$CLASS_PROMPT_2"\
+ --num_class_images=50 \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=1500
+```
+
+This example shows training for 2 subjects, but the model can be trained on any number of new concepts by appending the corresponding directories and prompts to the comma-separated strings above.
+
+Note also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed good when more intuitive words were used.
+
+**Important**: New parameters have been added to the script, making it possible to validate training progress by generating images at specified steps. Because commas often appear inside prompts themselves, packing several prompts into a single comma-separated string is error-prone. For that reason the script also introduces the `--concepts_list` parameter, which lets you point to a JSON file defining the configuration for each subject you want to train.
+
+An example of how to generate the file:
+```python
+import json
+
+# here we are using parameters for prior-preservation and validation as well.
+concepts_list = [
+ {
+ "instance_prompt": "drawing of a t@y meme",
+ "class_prompt": "drawing of a meme",
+ "instance_data_dir": "/some_folder/meme_toy",
+ "class_data_dir": "/data/meme",
+ "validation_prompt": "drawing of a t@y meme about football in Uruguay",
+ "validation_negative_prompt": "black and white"
+ },
+ {
+ "instance_prompt": "drawing of a sks sir",
+ "class_prompt": "drawing of a sir",
+ "instance_data_dir": "/some_other_folder/sir_sks",
+ "class_data_dir": "/data/sir",
+ "validation_prompt": "drawing of a sks sir with the Uruguayan sun in his chest",
+ "validation_negative_prompt": "an old man",
+ "validation_guidance_scale": 20,
+ "validation_number_images": 3,
+ "validation_inference_steps": 10
+ }
+]
+
+with open("concepts_list.json", "w") as f:
+ json.dump(concepts_list, f, indent=4)
+```
+And then just point to the file when executing the script:
+
+```bash
+# exports...
+accelerate launch train_multi_subject_dreambooth.py \
+# more parameters...
+--concepts_list="concepts_list.json"
+```
+
+You can use the script's `--help` output to get a better sense of each parameter.
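+
+For example, the following prints every available flag together with its description:
+
+```bash
+python train_multi_subject_dreambooth.py --help
+```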
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the identifier (e.g. `sks` or `t@y` in the above example) in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of a t@y person petting an sks dog"
+image = pipe(prompt, num_inference_steps=200, guidance_scale=7.5).images[0]
+
+image.save("person-petting-dog.png")
+```
+
+### Inference from a training checkpoint
+
+You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.
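+
+As a rough sketch, assuming your version of the script saves a diffusers-format `unet` subfolder inside each `checkpoint-*` directory (and a `text_encoder` subfolder when `--train_text_encoder` is used), with the checkpoint path below being purely hypothetical:
+
+```python
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+
+base_model = "CompVis/stable-diffusion-v1-4"
+ckpt_dir = "path-to-save-model/checkpoint-500"  # hypothetical intermediate checkpoint
+
+# Load the fine-tuned UNet from the checkpoint and plug it into the base pipeline.
+unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet")
+pipe = DiffusionPipeline.from_pretrained(base_model, unet=unet).to("cuda")
+
+image = pipe("A photo of a t@y person petting an sks dog", num_inference_steps=50).images[0]
+image.save("checkpoint-sample.png")
+```
+
+If your checkpoints only contain raw `accelerate` state, follow the linked documentation instead.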
+
+## Additional Dreambooth documentation
+Because the `train_multi_subject_dreambooth.py` script here was forked from an original version of `train_dreambooth.py` in the `examples/dreambooth` folder, I've included the original applicable training documentation for single subject examples below.
+
+This should explain how to play with training variables such as prior preservation, fine-tuning the text encoder, etc., all of which are still applicable to our multi-subject training code. Note also that the examples below, which are single-subject examples, also work with `train_multi_subject_dreambooth.py`, as this script supports 1 (or more) subjects.
+
+### Single subject dog toy example
+
+Let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
+
+And launch the training using
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400
+```
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+
+### Training on a 16GB GPU:
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run DreamBooth training on a 16GB GPU.
+
+To install `bitsandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Training on an 8 GB GPU:
+
+By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
+tensors from VRAM to either CPU or NVMe, allowing training with less VRAM.
+
+DeepSpeed needs to be enabled with `accelerate config`. During configuration,
+answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
+mixed precision and offloading both the parameters and the optimizer state to CPU, it's
+possible to train with under 8 GB of VRAM, with the drawback of requiring significantly
+more system RAM (about 25 GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
+
+Changing the default Adam optimizer to DeepSpeed's optimized version of Adam,
+`deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup, but enabling
+it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer
+does not seem to be compatible with DeepSpeed at the moment.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch --mixed_precision="fp16" train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --sample_batch_size=1 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Fine-tune text encoder with the UNet.
+
+The script also allows you to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
+Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.
+
+___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB of VRAM.___
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Using DreamBooth for other pipelines than Stable Diffusion
+
+AltDiffusion also supports DreamBooth now. The launch command is basically the same as above; all you need to do is change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion), for example:
+
+```
+export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
+or
+export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
+```
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
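+
+For example, assuming an xFormers build compatible with your PyTorch and CUDA versions is available, the flag can simply be appended to any of the launch commands above:
+
+```bash
+pip install xformers
+
+accelerate launch train_multi_subject_dreambooth.py \
+# same parameters as in the examples above...
+--enable_xformers_memory_efficient_attention
+```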
+
+You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/multi_subject_dreambooth/requirements.txt b/diffusers/examples/research_projects/multi_subject_dreambooth/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e19b0ce60bf407bec21a9b85f9232cad957bfa6f
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_subject_dreambooth/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/diffusers/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f507b26d6a81cf64e7c699e6d5e2c47d9d175b7
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -0,0 +1,1195 @@
+import argparse
+import itertools
+import json
+import logging
+import math
+import uuid
+import warnings
+from os import environ, listdir, makedirs
+from os.path import basename, join
+from pathlib import Path
+from typing import List
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image
+from torch import dtype
+from torch.nn import Module
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.13.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def log_validation_images_to_tracker(
+ images: List[np.ndarray], label: str, validation_prompt: str, accelerator: Accelerator, epoch: int
+):
+ logger.info(f"Logging images to tracker for validation prompt: {validation_prompt}.")
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{label}_{epoch}_{i}: {validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+
+# TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings`
+# argument is implemented.
+def generate_validation_images(
+ text_encoder: Module,
+ tokenizer: Module,
+ unet: Module,
+ vae: Module,
+ arguments: argparse.Namespace,
+ accelerator: Accelerator,
+ weight_dtype: dtype,
+):
+ logger.info("Running validation images.")
+
+ pipeline_args = {}
+
+ if text_encoder is not None:
+ pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+
+ if vae is not None:
+ pipeline_args["vae"] = vae
+
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ arguments.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ revision=arguments.revision,
+ torch_dtype=weight_dtype,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the
+ # scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ generator = (
+ None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed)
+ )
+
+ images_sets = []
+ for vp, nvi, vnp, vis, vgs in zip(
+ arguments.validation_prompt,
+ arguments.validation_number_images,
+ arguments.validation_negative_prompt,
+ arguments.validation_inference_steps,
+ arguments.validation_guidance_scale,
+ ):
+ images = []
+ if vp is not None:
+ logger.info(
+ f"Generating {nvi} images with prompt: '{vp}', negative prompt: '{vnp}', inference steps: {vis}, "
+ f"guidance scale: {vgs}."
+ )
+
+ pipeline_args = {"prompt": vp, "negative_prompt": vnp, "num_inference_steps": vis, "guidance_scale": vgs}
+
+ # run inference
+ # TODO: it would be good to measure whether it's faster to run inference on all images at once, one at a
+ # time or in small batches
+ for _ in range(nvi):
+ with torch.autocast("cuda"):
+ image = pipeline(**pipeline_args, num_images_per_prompt=1, generator=generator).images[0]
+ images.append(image)
+
+ images_sets.append(images)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images_sets
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=False,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=None,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt(s) `validation_prompt` "
+ "multiple times (`validation_number_images`) and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning. You can use commas to "
+ "define multiple prompts. This parameter can also be defined within the file given by the "
+ "`concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--validation_number_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with the validation parameters given. This "
+ "can be defined within the file given by the `concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--validation_negative_prompt",
+ type=str,
+ default=None,
+ help="A negative prompt that is used during validation to verify that the model is learning. You can use commas"
+ " to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can "
+ "also be defined within the file given by the `concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--validation_inference_steps",
+ type=int,
+ default=25,
+ help="Number of inference steps (denoising steps) to run during validation. This can be defined within the "
+ "file given by the `concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--validation_guidance_scale",
+ type=float,
+ default=7.5,
+ help="Controls how much the image generation process follows the text prompt. This can be defined within the "
+ "file given by the `concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--concepts_list",
+ type=str,
+ default=None,
+ help="Path to a JSON file containing a list of multiple concepts; it will overwrite parameters like instance_prompt,"
+ " class_prompt, etc.",
+ )
+
+ if input_args:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if not args.concepts_list and (not args.instance_data_dir or not args.instance_prompt):
+ raise ValueError(
+ "You must specify either instance parameters (data directory, prompt, etc.) or use "
+ "the `concepts_list` parameter and specify them within the file."
+ )
+
+ if args.concepts_list:
+ if args.instance_prompt:
+ raise ValueError("If you are using the `concepts_list` parameter, define the instance prompt within the file.")
+ if args.instance_data_dir:
+ raise ValueError(
+ "If you are using the `concepts_list` parameter, define the instance data directory within the file."
+ )
+ if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt):
+ raise ValueError(
+ "If you are using the `concepts_list` parameter, define validation parameters for "
+ "each subject within the file:\n - `validation_prompt`."
+ "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`."
+ "\n - `validation_number_images`."
+ "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one "
+ "that needs to be defined outside the file."
+ )
+
+ env_local_rank = int(environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if not args.concepts_list:
+ if not args.class_data_dir:
+ raise ValueError("You must specify a data directory for class images.")
+ if not args.class_prompt:
+ raise ValueError("You must specify a prompt for class images.")
+ else:
+ if args.class_data_dir:
+ raise ValueError(
+ "If you are using the `concepts_list` parameter, define the class data directory within the file."
+ )
+ if args.class_prompt:
+ raise ValueError(
+ "If you are using the `concepts_list` parameter, define the class prompt within the file."
+ )
+ else:
+ # logger is not available yet
+ if args.class_data_dir:
+ warnings.warn(
+ "Ignoring the `class_data_dir` parameter, you need to use it together with `with_prior_preservation`."
+ )
+ if args.class_prompt:
+ warnings.warn(
+ "Ignoring the `class_prompt` parameter, you need to use it together with `with_prior_preservation`."
+ )
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and then tokenizes prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = []
+ self.instance_images_path = []
+ self.num_instance_images = []
+ self.instance_prompt = []
+ self.class_data_root = [] if class_data_root is not None else None
+ self.class_images_path = []
+ self.num_class_images = []
+ self.class_prompt = []
+ self._length = 0
+
+ for i in range(len(instance_data_root)):
+ self.instance_data_root.append(Path(instance_data_root[i]))
+ if not self.instance_data_root[i].exists():
+ raise ValueError("Instance images root doesn't exist.")
+
+ self.instance_images_path.append(list(Path(instance_data_root[i]).iterdir()))
+ self.num_instance_images.append(len(self.instance_images_path[i]))
+ self.instance_prompt.append(instance_prompt[i])
+ self._length += self.num_instance_images[i]
+
+ if class_data_root is not None:
+ self.class_data_root.append(Path(class_data_root[i]))
+ self.class_data_root[i].mkdir(parents=True, exist_ok=True)
+ self.class_images_path.append(list(self.class_data_root[i].iterdir()))
+ self.num_class_images.append(len(self.class_images_path[i]))
+ if self.num_class_images[i] > self.num_instance_images[i]:
+ self._length -= self.num_instance_images[i]
+ self._length += self.num_class_images[i]
+ self.class_prompt.append(class_prompt[i])
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ for i in range(len(self.instance_images_path)):
+ instance_image = Image.open(self.instance_images_path[i][index % self.num_instance_images[i]])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example[f"instance_images_{i}"] = self.image_transforms(instance_image)
+ example[f"instance_prompt_ids_{i}"] = self.tokenizer(
+ self.instance_prompt[i],
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ if self.class_data_root:
+ for i in range(len(self.class_data_root)):
+ class_image = Image.open(self.class_images_path[i][index % self.num_class_images[i]])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example[f"class_images_{i}"] = self.image_transforms(class_image)
+ example[f"class_prompt_ids_{i}"] = self.tokenizer(
+ self.class_prompt[i],
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ return example
+
+
+def collate_fn(num_instances, examples, with_prior_preservation=False):
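+ # Flatten per-subject prompts and images from all examples into single batch lists.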
+ input_ids = []
+ pixel_values = []
+
+ for i in range(num_instances):
+ input_ids += [example[f"instance_prompt_ids_{i}"] for example in examples]
+ pixel_values += [example[f"instance_images_{i}"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ for i in range(num_instances):
+ input_ids += [example[f"class_prompt_ids_{i}"] for example in examples]
+ pixel_values += [example[f"class_images_{i}"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+ return batch
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ instance_data_dir = []
+ instance_prompt = []
+ class_data_dir = [] if args.with_prior_preservation else None
+ class_prompt = [] if args.with_prior_preservation else None
+ if args.concepts_list:
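+ # `concepts_list` is expected to be a JSON list of dicts, one per subject, with keys such as
+ # "instance_data_dir", "instance_prompt", optionally "class_data_dir"/"class_prompt", and the
+ # per-subject validation_* settings read below.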
+ with open(args.concepts_list, "r") as f:
+ concepts_list = json.load(f)
+
+ if args.validation_steps:
+ args.validation_prompt = []
+ args.validation_number_images = []
+ args.validation_negative_prompt = []
+ args.validation_inference_steps = []
+ args.validation_guidance_scale = []
+
+ for concept in concepts_list:
+ instance_data_dir.append(concept["instance_data_dir"])
+ instance_prompt.append(concept["instance_prompt"])
+
+ if args.with_prior_preservation:
+ try:
+ class_data_dir.append(concept["class_data_dir"])
+ class_prompt.append(concept["class_prompt"])
+ except KeyError:
+ raise KeyError(
+ "`class_data_dir` or `class_prompt` not found in concepts_list while using "
+ "`with_prior_preservation`."
+ )
+ else:
+ if "class_data_dir" in concept:
+ warnings.warn(
+ "Ignoring `class_data_dir` key, to use it you need to enable `with_prior_preservation`."
+ )
+ if "class_prompt" in concept:
+ warnings.warn(
+ "Ignoring `class_prompt` key, to use it you need to enable `with_prior_preservation`."
+ )
+
+ if args.validation_steps:
+ args.validation_prompt.append(concept.get("validation_prompt", None))
+ args.validation_number_images.append(concept.get("validation_number_images", 4))
+ args.validation_negative_prompt.append(concept.get("validation_negative_prompt", None))
+ args.validation_inference_steps.append(concept.get("validation_inference_steps", 25))
+ args.validation_guidance_scale.append(concept.get("validation_guidance_scale", 7.5))
+ else:
+ # Parse instance and class inputs, and double check that lengths match
+ instance_data_dir = args.instance_data_dir.split(",")
+ instance_prompt = args.instance_prompt.split(",")
+ assert all(
+ x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
+ ), "Instance data dir and prompt inputs are not of the same length."
+
+ if args.with_prior_preservation:
+ class_data_dir = args.class_data_dir.split(",")
+ class_prompt = args.class_prompt.split(",")
+ assert all(
+ x == len(instance_data_dir)
+ for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)]
+ ), "Instance & class data dir or prompt inputs are not of the same length."
+
+ if args.validation_steps:
+ validation_prompts = args.validation_prompt.split(",")
+ num_of_validation_prompts = len(validation_prompts)
+ args.validation_prompt = validation_prompts
+ args.validation_number_images = [args.validation_number_images] * num_of_validation_prompts
+
+ negative_validation_prompts = [None] * num_of_validation_prompts
+ if args.validation_negative_prompt:
+ negative_validation_prompts = args.validation_negative_prompt.split(",")
+ while len(negative_validation_prompts) < num_of_validation_prompts:
+ negative_validation_prompts.append(None)
+ args.validation_negative_prompt = negative_validation_prompts
+
+ assert num_of_validation_prompts == len(
+ negative_validation_prompts
+ ), "The number of negative validation prompts is greater than the number of validation prompts."
+ args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
+ args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ for i in range(len(class_data_dir)):
+ class_images_dir = Path(class_data_dir[i])
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(class_prompt[i], num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for ii, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = (
+ class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg"
+ )
+ image.save(image_filename)
+
+ # Clean up the memory deleting one-time-use variables.
+ del pipeline
+ del sample_dataloader
+ del sample_dataset
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ tokenizer = None
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ vae.requires_grad_(False)
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=instance_data_dir,
+ instance_prompt=instance_prompt,
+ class_data_root=class_data_dir,
+ class_prompt=class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(len(instance_data_dir), examples, args.with_prior_preservation),
+ num_workers=1,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and text_encoder to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if not args.train_text_encoder:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ time_steps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ )
+ time_steps = time_steps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, time_steps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, time_steps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, time_steps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ save_path = join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if (
+ args.validation_steps
+ and any(args.validation_prompt)
+ and global_step % args.validation_steps == 0
+ ):
+ images_set = generate_validation_images(
+ text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype
+ )
+ for images, validation_prompt in zip(images_set, args.validation_prompt):
+ if len(images) > 0:
+ label = str(uuid.uuid1())[:8] # generate an id for each set of validation images
+ log_validation_images_to_tracker(
+ images, label, validation_prompt, accelerator, global_step
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ revision=args.revision,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/multi_subject_dreambooth_inpainting/README.md b/diffusers/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ffd8e304efceac7e5f921c9f3f45dcf9957939f8
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
@@ -0,0 +1,93 @@
+# Multi Subject Dreambooth for Inpainting Models
+
+Please note that this project is not actively maintained. However, you can open an issue and tag @gzguevara.
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like Stable Diffusion given just a few (3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requires prompt-image-mask pairs. The UNet of inpainting models has 5 additional input channels (4 for the encoded masked image and 1 for the mask itself).
+
+**The first part**, the `multi_inpaint_dataset.ipynb` notebook, demonstrates how to make a 🤗 dataset of prompt-image-mask pairs. You can, however, skip the first part and move straight to the second part with the example datasets in this project. ([cat toy dataset masked](https://huggingface.co/datasets/gzguevara/cat_toy_masked), [mr. potato head dataset masked](https://huggingface.co/datasets/gzguevara/mr_potato_head_masked))
+
+**The second part**, the `train_multi_subject_dreambooth_inpainting.py` training script, demonstrates how to implement a training procedure for one or more subjects and adapt it to Stable Diffusion for inpainting.
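+
+To make those extra channels concrete, here is a minimal sketch (mirroring the concatenation done inside the training script) of how the inpainting UNet input is assembled from the noisy image latents, the mask resized to latent resolution, and the latents of the masked image. The shapes are illustrative and assume a 512x512 image with the usual VAE downsampling factor of 8:
+
+```py
+import torch
+
+# Illustrative tensors for a single 512x512 example at latent resolution (512 / 8 = 64).
+noisy_latents = torch.randn(1, 4, 64, 64)   # noised latents of the full image
+mask = torch.ones(1, 1, 64, 64)             # binary mask, resized to the latent resolution
+masked_latents = torch.randn(1, 4, 64, 64)  # latents of the image with the masked region blanked out
+
+# 4 + 1 + 4 = 9 input channels for the inpainting UNet
+latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
+print(latent_model_input.shape)  # torch.Size([1, 9, 64, 64])
+```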
+
+## 1. Data Collection: Make Prompt-Image-Mask Pairs
+
+Earlier training scripts have relied on approaches like random masking of the training images. This project provides a notebook for setting masks more precisely.
+
+The notebook can be found here: [Google Colab](https://colab.research.google.com/drive/1JNEASI_B7pLW1srxhgln6nM0HoGAQT32?usp=sharing)
+
+The `multi_inpaint_dataset.ipynb` notebook takes training & validation images, on which the user draws masks and provides prompts, to make prompt-image-mask pairs. This ensures that during training the loss is computed on the area masking the object of interest, rather than on random areas. Moreover, the `multi_inpaint_dataset.ipynb` notebook allows you to build a validation dataset with corresponding masks for monitoring the training process. Example below:
+
+
+
+You can build multiple datasets for every subject and upload them to the 🤗 Hub. Later, when launching the training script, you can indicate the paths of the datasets on which you would like to fine-tune Stable Diffusion for inpainting.
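+
+If you want a quick look at what the training script expects, the sketch below loads one of the example datasets and lists its columns. It assumes the 🤗 `datasets` library is installed; the exact mask column names can differ per dataset, since the training script simply treats every column whose name contains `mask` as a mask for that image:
+
+```py
+from datasets import load_dataset
+
+# One of the example datasets linked above; swap in your own 🤗 dataset path.
+dataset = load_dataset("gzguevara/cat_toy_masked")
+
+sample = dataset["train"][0]
+print(dataset)              # train/test splits, both used by the training script
+print(list(sample.keys()))  # expected: "image", "prompt", and one or more mask columns
+
+# Columns picked up as masks by the training script
+print([key for key in sample if "mask" in key])
+```
+
+When launching the training script (section 2.2), each of these dataset paths is passed to `--instance_data_dir`, which accepts one or more paths.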
+
+## 2. Train Multi Subject Dreambooth for Inpainting
+
+### 2.1. Setting The Training Configuration
+
+Before launching the training script, make sure to select the target inpainting model, the output directory and the 🤗 datasets.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export OUTPUT_DIR="path-to-save-model"
+
+export DATASET_1="gzguevara/mr_potato_head_masked"
+export DATASET_2="gzguevara/cat_toy_masked"
+... # Further paths to 🤗 datasets
+```
+
+### 2.2. Launching The Training Script
+
+```bash
+accelerate launch train_multi_subject_dreambooth_inpainting.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir $DATASET_1 $DATASET_2 \
+ --output_dir=$OUTPUT_DIR \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 \
+ --learning_rate=3e-6 \
+ --max_train_steps=500 \
+ --report_to_wandb
+```
+
+### 2.3. Fine-tune text encoder with the UNet.
+
+The script also allows you to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
+Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.
+
+___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB VRAM.___
+
+```bash
+accelerate launch train_multi_subject_dreambooth_inpainting.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir $DATASET_1 $DATASET_2 \
+ --output_dir=$OUTPUT_DIR \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 \
+ --learning_rate=2e-6 \
+ --max_train_steps=500 \
+ --report_to_wandb \
+ --train_text_encoder
+```
+
+## 3. Results
+
+A [Weights & Biases report](https://wandb.ai/gzguevara/uncategorized/reports/Multi-Subject-Dreambooth-for-Inpainting--Vmlldzo2MzY5NDQ4?accessToken=y0nya2d7baguhbryxaikbfr1203amvn1jsmyl07vk122mrs7tnph037u1nqgse8t) is provided, showing the training progress every 50 steps. Note that the reported Weights & Biases run was performed on an A100 GPU with the following settings:
+
+```bash
+accelerate launch train_multi_subject_dreambooth_inpainting.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir $DATASET_1 $DATASET_2 \
+ --output_dir=$OUTPUT_DIR \
+ --resolution=512 \
+ --train_batch_size=10 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-6 \
+ --max_train_steps=500 \
+ --report_to_wandb \
+ --train_text_encoder
+```
+Here you can see the target objects on my desk and next to my plant:
+
+
diff --git a/diffusers/examples/research_projects/multi_subject_dreambooth_inpainting/requirements.txt b/diffusers/examples/research_projects/multi_subject_dreambooth_inpainting/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..351287c77cbe1c6ca8917dc1ff6f184ddbb1378a
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_subject_dreambooth_inpainting/requirements.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets>=2.16.0
+wandb>=0.16.1
+ftfy
+tensorboard
+Jinja2
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/multi_subject_dreambooth_inpainting/train_multi_subject_dreambooth_inpainting.py b/diffusers/examples/research_projects/multi_subject_dreambooth_inpainting/train_multi_subject_dreambooth_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..567e5e3caaa36896b278aa816ffe32638b7b8734
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_subject_dreambooth_inpainting/train_multi_subject_dreambooth_inpainting.py
@@ -0,0 +1,661 @@
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import concatenate_datasets, load_dataset
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionInpaintPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.13.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument("--instance_data_dir", nargs="+", help="Instance data directories")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder", default=False, action="store_true", help="Whether to train the text encoder"
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
+ " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=1000,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
+ " using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpointing_from",
+ type=int,
+ default=1000,
+ help=("Step at which to start saving checkpoints."),
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=50,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt-image-mask pairs"
+ " from the test split of the training datasets through the pipeline"
+ " and logging the resulting images."
+ ),
+ )
+ parser.add_argument(
+ "--validation_from",
+ type=int,
+ default=0,
+ help=("Step at which to start running validation."),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_project_name",
+ type=str,
+ default=None,
+ help="The Weights & Biases project name.",
+ )
+ parser.add_argument(
+ "--report_to_wandb", default=False, action="store_true", help="Whether to report to weights and biases"
+ )
+
+ args = parser.parse_args()
+
+ return args
+
+
+def prepare_mask_and_masked_image(image, mask):
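+ # Scale the RGB image to [-1, 1], binarize the mask at 0.5, and zero out the masked region to
+ # produce the masked image that is later encoded into the extra UNet input channels.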
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ return mask, masked_image
+
+
+class DreamBoothDataset(Dataset):
+ def __init__(
+ self,
+ tokenizer,
+ datasets_paths,
+ ):
+ self.tokenizer = tokenizer
+ self.datasets_paths = datasets_paths
+ self.datasets = [load_dataset(dataset_path) for dataset_path in self.datasets_paths]
+ self.train_data = concatenate_datasets([dataset["train"] for dataset in self.datasets])
+ self.test_data = concatenate_datasets([dataset["test"] for dataset in self.datasets])
+
+ self.image_normalize = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def set_image(self, img, switch):
+ if img.mode not in ["RGB", "L"]:
+ img = img.convert("RGB")
+
+ if switch:
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+
+ img = img.resize((512, 512), Image.BILINEAR)
+
+ return img
+
+ def __len__(self):
+ return len(self.train_data)
+
+ def __getitem__(self, index):
+ # Pick the training example and randomly decide whether to flip it horizontally
+ example = {}
+ img_idx = index % len(self.train_data)
+ switch = random.choice([True, False])
+
+ # Load image
+ image = self.set_image(self.train_data[img_idx]["image"], switch)
+
+ # Normalize image
+ image_norm = self.image_normalize(image)
+
+ # Tokenise prompt
+ tokenized_prompt = self.tokenizer(
+ self.train_data[img_idx]["prompt"],
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ # Load masks for image
+ masks = [
+ self.set_image(self.train_data[img_idx][key], switch) for key in self.train_data[img_idx] if "mask" in key
+ ]
+
+ # Build example
+ example["PIL_image"] = image
+ example["instance_image"] = image_norm
+ example["instance_prompt_id"] = tokenized_prompt
+ example["instance_masks"] = masks
+
+ return example
+
+
+def weighted_mask(masks):
+ # Convert each mask to a NumPy array and ensure it's binary
+ mask_arrays = [np.array(mask) / 255 for mask in masks] # Normalizing to 0-1 range
+
+ # Generate random weights and apply them to each mask
+ weights = [random.random() for _ in masks]
+ weights = [weight / sum(weights) for weight in weights]
+ weighted_masks = [mask * weight for mask, weight in zip(mask_arrays, weights)]
+
+ # Sum the weighted masks
+ summed_mask = np.sum(weighted_masks, axis=0)
+
+ # Apply a threshold to create the final mask
+ threshold = 0.5 # This threshold can be adjusted
+ result_mask = summed_mask >= threshold
+
+ # Convert the result back to a PIL image
+ return Image.fromarray(result_mask.astype(np.uint8) * 255)
+
+
+def collate_fn(examples, tokenizer):
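+ # Build a batch of padded prompt ids, image tensors, and (mask, masked image) pairs; the mask for
+ # each example is a randomly weighted combination of all of its annotated masks (see weighted_mask).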
+ input_ids = [example["instance_prompt_id"] for example in examples]
+ pixel_values = [example["instance_image"] for example in examples]
+
+ masks, masked_images = [], []
+
+ for example in examples:
+ # generate a random mask
+ mask = weighted_mask(example["instance_masks"])
+
+ # prepare mask and masked image
+ mask, masked_image = prepare_mask_and_masked_image(example["PIL_image"], mask)
+
+ masks.append(mask)
+ masked_images.append(masked_image)
+
+ pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
+ masks = torch.stack(masks)
+ masked_images = torch.stack(masked_images)
+ input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+
+ batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
+
+ return batch
+
+
+def log_validation(pipeline, text_encoder, unet, val_pairs, accelerator):
+ # swap in the (unwrapped) text_encoder and unet that are being trained
+ pipeline.text_encoder = accelerator.unwrap_model(text_encoder)
+ pipeline.unet = accelerator.unwrap_model(unet)
+
+ with torch.autocast("cuda"):
+ val_results = [{"data_or_path": pipeline(**pair).images[0], "caption": pair["prompt"]} for pair in val_pairs]
+
+ torch.cuda.empty_cache()
+
+ wandb.log({"validation": [wandb.Image(**val_result) for val_result in val_results]})
+
+
+def checkpoint(args, global_step, accelerator):
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+
+def main():
+ args = parse_args()
+
+ project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit,
+ project_dir=args.output_dir,
+ logging_dir=Path(args.output_dir, args.logging_dir),
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ project_config=project_config,
+ log_with="wandb" if args.report_to_wandb else None,
+ )
+
+ if args.report_to_wandb and not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+
+ # Load the tokenizer & models and create wrapper for stable diffusion
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder"
+ ).requires_grad_(args.train_text_encoder)
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae").requires_grad_(False)
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ optimizer = torch.optim.AdamW(
+ params=itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ train_dataset = DreamBoothDataset(
+ tokenizer=tokenizer,
+ datasets_paths=args.instance_data_dir,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, tokenizer),
+ )
+
+ # Scheduler and math around the number of training steps.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ accelerator.register_for_checkpointing(lr_scheduler)
+
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ else:
+ weight_dtype = torch.float32
+
+ # Move text_encoder and vae to gpu.
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if not args.train_text_encoder:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
+ # Afterwards we calculate our number of training epochs
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ accelerator.init_trackers(args.validation_project_name, config=tracker_config)
+
+ # create the validation pipeline, reusing the tokenizer and the models being trained
+ val_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ vae=vae,
+ torch_dtype=weight_dtype,
+ safety_checker=None,
+ )
+ val_pipeline.set_progress_bar_config(disable=True)
+
+ # prepare validation dataset
+ val_pairs = [
+ {
+ "image": example["image"],
+ "mask_image": mask,
+ "prompt": example["prompt"],
+ }
+ for example in train_dataset.test_data
+ for mask in [example[key] for key in example if "mask" in key]
+ ]
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+ first_epoch = 0
+
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Convert masked images to latent space
+ masked_latents = vae.encode(
+ batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
+ ).latent_dist.sample()
+ masked_latents = masked_latents * vae.config.scaling_factor
+
+ masks = batch["masks"]
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ mask = torch.stack(
+ [
+ torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
+ for mask in masks
+ ]
+ )
+ mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # concatenate the noised latents with the mask and the masked latents
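+                # the inpainting unet takes 4 (noisy latents) + 1 (mask) + 4 (masked-image latents) = 9 input channels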
+ latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if (
+ global_step % args.validation_steps == 0
+ and global_step >= args.validation_from
+ and args.report_to_wandb
+ ):
+ log_validation(
+ val_pipeline,
+ text_encoder,
+ unet,
+ val_pairs,
+ accelerator,
+ )
+
+ if global_step % args.checkpointing_steps == 0 and global_step >= args.checkpointing_from:
+ checkpoint(
+ args,
+ global_step,
+ accelerator,
+ )
+
+ # Step logging
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ accelerator.wait_for_everyone()
+
+ # Terminate training
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/multi_token_textual_inversion/README.md b/diffusers/examples/research_projects/multi_token_textual_inversion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5e0aaf2c0575b5f404dbf9c4b4ec9ac5e4e00a3b
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_token_textual_inversion/README.md
@@ -0,0 +1,143 @@
+## [Deprecated] Multi Token Textual Inversion
+
+**IMPORTANT: This research project is deprecated. Multi Token Textual Inversion is now supported natively in [the official textual inversion example](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#running-locally-with-pytorch).**
+
+The author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author for issues and PRs as well as @patrickvonplaten.
+
+We add multi-token support to textual inversion. I added
+1. `num_vec_per_token`: the number of vectors used to represent the placeholder token
+2. `progressive_tokens`: progressively train the embedding, starting from 1 token and working up to `num_vec_per_token` tokens
+3. `progressive_tokens_max_steps`: the number of steps after which all tokens are used
+4. `vector_shuffle`: shuffle the token vectors during training
+
+Feel free to add these options to your training! In practice, `num_vec_per_token` around 10 combined with `vector_shuffle` works great; see the example launch command below.
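+
+As a concrete illustration, here is a minimal sketch of how these flags could be appended to the training command from the cat toy example below (the model name, data directory, and step values are placeholders, not recommendations):
+
+```bash
+accelerate launch textual_inversion.py \
+  --pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5" \
+  --train_data_dir="path-to-dir-containing-images" \
+  --placeholder_token="<cat-toy>" --initializer_token="toy" \
+  --num_vec_per_token=10 \
+  --vector_shuffle \
+  --progressive_tokens \
+  --progressive_tokens_max_steps=2000 \
+  --output_dir="textual_inversion_cat"
+```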
+
+## Textual Inversion fine-tuning example
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Running on Colab
+
+Colab for training
+[Open in Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+
+Colab for inference
+[Open in Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+
+### Cat toy example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+Now let's get our dataset. Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
+
+And launch the training using
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export DATA_DIR="path-to-dir-containing-images"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="textual_inversion_cat"
+```
+
+A full training run takes ~1 hour on one V100 GPU.
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A <cat-toy> backpack"
+
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("cat-backpack.png")
+```
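+
+If you trained with `num_vec_per_token > 1`, the learned vectors are written by `save_progress` to `learned_embeds.bin` in the output directory. Below is a minimal sketch (not an official API) of how they could be loaded back using the helpers from this folder's `multi_token_clip.py` and `textual_inversion.py`; the base model and file paths are placeholders you should adapt:
+
+```python
+import torch
+from transformers import CLIPTextModel
+from diffusers import StableDiffusionPipeline
+from multi_token_clip import MultiTokenCLIPTokenizer
+from textual_inversion import load_multitoken_tokenizer
+
+base_model = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # placeholder base checkpoint
+tokenizer = MultiTokenCLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")
+
+# learned_embeds.bin maps the placeholder token to its stacked embedding vectors
+learned_embeds_dict = torch.load("textual_inversion_cat/learned_embeds.bin")
+load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict)
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    base_model, tokenizer=tokenizer, text_encoder=text_encoder
+).to("cuda")
+image = pipe("A <cat-toy> backpack", num_inference_steps=50, guidance_scale=7.5).images[0]
+```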
+
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="path-to-dir-containing-images"
+
+python textual_inversion_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --output_dir="textual_inversion_cat"
+```
+It should be at least 70% faster than the PyTorch script with the same configuration.
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script, as in the sketch below. This is not available with the Flax/JAX implementation.
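+
+For example, a sketch of the PyTorch launch command from above with the flag appended (all other arguments unchanged):
+
+```bash
+accelerate launch textual_inversion.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATA_DIR \
+  --placeholder_token="<cat-toy>" --initializer_token="toy" \
+  --resolution=512 \
+  --output_dir="textual_inversion_cat" \
+  --enable_xformers_memory_efficient_attention
+```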
diff --git a/diffusers/examples/research_projects/multi_token_textual_inversion/multi_token_clip.py b/diffusers/examples/research_projects/multi_token_textual_inversion/multi_token_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdc344376ada63da590729cb82bc521c6fd2b5d1
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_token_textual_inversion/multi_token_clip.py
@@ -0,0 +1,104 @@
+"""
+The main idea for this code is to provide a way for users to not need to bother with the hassle of multiple tokens for a concept by typing
+a photo of <concept>_0 <concept>_1 ... and so on
+and instead just do
+a photo of <concept>
+which gets translated to the above. This needs to work for both inference and training.
+For inference,
+the tokenizer encodes the text. So, we would want logic for our tokenizer to replace the placeholder token with
+its underlying vectors
+For training,
+we would want to abstract away some logic like
+1. Adding tokens
+2. Updating gradient mask
+3. Saving embeddings
+to our Util class here.
+so
+TODO:
+1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x
+2. have mechanism for adding tokens x
+3. have mechanism for saving embeddings x
+4. get mask to update x
+5. Loading tokens from embedding x
+6. Integrate to training x
+7. Test
+"""
+
+import copy
+import random
+
+from transformers import CLIPTokenizer
+
+
+class MultiTokenCLIPTokenizer(CLIPTokenizer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.token_map = {}
+
+ def try_adding_tokens(self, placeholder_token, *args, **kwargs):
+ num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
+ output = []
+ if num_vec_per_token == 1:
+ self.try_adding_tokens(placeholder_token, *args, **kwargs)
+ output.append(placeholder_token)
+ else:
+ output = []
+ for i in range(num_vec_per_token):
+ ith_token = placeholder_token + f"_{i}"
+ self.try_adding_tokens(ith_token, *args, **kwargs)
+ output.append(ith_token)
+ # handle cases where there is a new placeholder token that contains the current placeholder token but is larger
+ for token in self.token_map:
+ if token in placeholder_token:
+ raise ValueError(
+ f"The tokenizer already has placeholder token {token} that can get confused with"
+ f" {placeholder_token}keep placeholder tokens independent"
+ )
+ self.token_map[placeholder_token] = output
+
+ def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
+ """
+ Here, we replace the placeholder tokens in text recorded in token_map so that the text_encoder
+ can encode them
+ vector_shuffle was inspired by https://github.com/rinongal/textual_inversion/pull/119
+        where shuffling the tokens was found to force the model to learn the concepts more descriptively.
+ """
+ if isinstance(text, list):
+ output = []
+ for i in range(len(text)):
+                # propagate prop_tokens_to_load so progressive token training also works for list inputs
+                output.append(
+                    self.replace_placeholder_tokens_in_text(
+                        text[i], vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
+                    )
+                )
+ return output
+ for placeholder_token in self.token_map:
+ if placeholder_token in text:
+ tokens = self.token_map[placeholder_token]
+ tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
+ if vector_shuffle:
+ tokens = copy.copy(tokens)
+ random.shuffle(tokens)
+ text = text.replace(placeholder_token, " ".join(tokens))
+ return text
+
+ def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
+ return super().__call__(
+ self.replace_placeholder_tokens_in_text(
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
+ ),
+ *args,
+ **kwargs,
+ )
+
+ def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
+ return super().encode(
+ self.replace_placeholder_tokens_in_text(
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
+ ),
+ *args,
+ **kwargs,
+ )
diff --git a/diffusers/examples/research_projects/multi_token_textual_inversion/requirements.txt b/diffusers/examples/research_projects/multi_token_textual_inversion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7a612982f4abbaa64f83db52e411a1235a372259
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_token_textual_inversion/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/research_projects/multi_token_textual_inversion/requirements_flax.txt b/diffusers/examples/research_projects/multi_token_textual_inversion/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8f85ad523a3b46b65abf0138c05ecdd656e6845c
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_token_textual_inversion/requirements_flax.txt
@@ -0,0 +1,8 @@
+transformers>=4.25.1
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/diffusers/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..57ad77477b0d491044a1fc852a7ac3abe9563b1a
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -0,0 +1,937 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from multi_token_clip import MultiTokenCLIPTokenizer
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.14.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
+ """
+ Add tokens to the tokenizer and set the initial value of token embeddings
+ """
+ tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)
+ text_encoder.resize_token_embeddings(len(tokenizer))
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ if initializer_token:
+ token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
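+            # spread the initializer token's embedding(s) across the placeholder vectors;
+            # if the initializer has fewer tokens than num_vec_per_token, its embeddings are reused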
+ token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]
+ else:
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
+ token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])
+ return placeholder_token
+
+
+def save_progress(tokenizer, text_encoder, accelerator, save_path):
+ for placeholder_token in tokenizer.token_map:
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_ids]
+ if len(placeholder_token_ids) == 1:
+ learned_embeds = learned_embeds[None]
+ learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, save_path)
+
+
+def load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict):
+ for placeholder_token in learned_embeds_dict:
+ placeholder_embeds = learned_embeds_dict[placeholder_token]
+ num_vec_per_token = placeholder_embeds.shape[0]
+ placeholder_embeds = placeholder_embeds.to(dtype=text_encoder.dtype)
+ add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=num_vec_per_token)
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
+ token_embeds[placeholder_token_id] = placeholder_embeds[i]
+
+
+def load_multitoken_tokenizer_from_automatic(tokenizer, text_encoder, automatic_dict, placeholder_token):
+ """
+ Automatic1111's tokens have format
+ {'string_to_token': {'*': 265}, 'string_to_param': {'*': tensor([[ 0.0833, 0.0030, 0.0057, ..., -0.0264, -0.0616, -0.0529],
+ [ 0.0058, -0.0190, -0.0584, ..., -0.0025, -0.0945, -0.0490],
+ [ 0.0916, 0.0025, 0.0365, ..., -0.0685, -0.0124, 0.0728],
+ [ 0.0812, -0.0199, -0.0100, ..., -0.0581, -0.0780, 0.0254]],
+ requires_grad=True)}, 'name': 'FloralMarble-400', 'step': 399, 'sd_checkpoint': '4bdfc29c', 'sd_checkpoint_name': 'SD2.1-768'}
+ """
+ learned_embeds_dict = {}
+ learned_embeds_dict[placeholder_token] = automatic_dict["string_to_param"]["*"]
+ load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict)
+
+
+def get_mask(tokenizer, accelerator):
+ # Get the mask of the weights that won't change
+ mask = torch.ones(len(tokenizer)).to(accelerator.device, dtype=torch.bool)
+ for placeholder_token in tokenizer.token_map:
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ for i in range(len(placeholder_token_ids)):
+ mask = mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i]).to(accelerator.device)
+ return mask
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--progressive_tokens_max_steps",
+ type=int,
+ default=2000,
+ help="The number of steps until all tokens will be used.",
+ )
+ parser.add_argument(
+ "--progressive_tokens",
+ action="store_true",
+ help="Progressively train the tokens. For example, first train for 1 token, then 2 tokens and so on.",
+ )
+ parser.add_argument("--vector_shuffle", action="store_true", help="Shuffling tokens durint training")
+ parser.add_argument(
+ "--num_vec_per_token",
+ type=int,
+ default=1,
+ help=(
+ "The number of vectors used to represent the placeholder token. The higher the number, the better the"
+ " result at the cost of editability. This can be fixed by prompt editing."
+ ),
+ )
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
+ parser.add_argument(
+ "--only_save_embeds",
+ action="store_true",
+ default=False,
+ help="Save only the embeddings for the new concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run validation every X epochs. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ vector_shuffle=False,
+ progressive_tokens=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+ self.vector_shuffle = vector_shuffle
+ self.progressive_tokens = progressive_tokens
+ self.prop_tokens_to_load = 0
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer.encode(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ vector_shuffle=self.vector_shuffle,
+ prop_tokens_to_load=self.prop_tokens_to_load if self.progressive_tokens else 1.0,
+ )[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load tokenizer
+ if args.tokenizer_name:
+ tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ if is_xformers_available():
+ try:
+ unet.enable_xformers_memory_efficient_attention()
+ except Exception as e:
+ logger.warning(
+ "Could not enable memory efficient attention. Make sure xformers is installed"
+ f" correctly and a GPU is available: {e}"
+ )
+ add_tokens(tokenizer, text_encoder, args.placeholder_token, args.num_vec_per_token, args.initializer_token)
+
+ # Freeze vae and unet
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ # Freeze all parameters except for the token embeddings in text encoder
+ text_encoder.text_model.encoder.requires_grad_(False)
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ # Keep unet in train mode if we are using gradient checkpointing to save memory.
+        # Since dropout is 0 in the unet, it doesn't matter whether we are in eval or train mode.
+ unet.train()
+ text_encoder.gradient_checkpointing_enable()
+ unet.enable_gradient_checkpointing()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the unet and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and unet to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ # keep original embeddings as reference
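+    # (restored after every optimizer step for all rows except the newly added placeholder token embeddings)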
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+ if args.progressive_tokens:
+ train_dataset.prop_tokens_to_load = float(global_step) / args.progressive_tokens_max_steps
+
+ with accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ index_no_updates = get_mask(tokenizer, accelerator)
+ with torch.no_grad():
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+ save_progress(tokenizer, text_encoder, accelerator, save_path)
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=unet,
+ vae=vae,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = (
+ None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ )
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.push_to_hub and args.only_save_embeds:
+ logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = not args.only_save_embeds
+ if save_full_model:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+ save_progress(tokenizer, text_encoder, accelerator, save_path)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py b/diffusers/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecc89f98298e3e4205581fee1689761c519bc4e4
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py
@@ -0,0 +1,654 @@
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import PIL
+import torch
+import torch.utils.checkpoint
+import transformers
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxDDPMScheduler,
+ FlaxPNDMScheduler,
+ FlaxStableDiffusionPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.14.0.dev0")
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=True,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument(
+ "--use_auth_token",
+ action="store_true",
+ help=(
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+ " private models)."
+ ),
+ )
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+        if image.mode != "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+            h, w = img.shape[0], img.shape[1]
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
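+    # Grow the Flax text encoder's token-embedding table to `new_num_tokens`, keep the
+    # existing rows, and copy the initializer token's embedding into the placeholder
+    # token's row so training starts from a meaningful vector.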
+ if model.config.vocab_size == new_num_tokens or new_num_tokens is None:
+ return
+ model.config.vocab_size = new_num_tokens
+
+ params = model.params
+ old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+ old_num_tokens, emb_dim = old_embeddings.shape
+
+ initializer = jax.nn.initializers.normal()
+
+ new_embeddings = initializer(rng, (new_num_tokens, emb_dim))
+ new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)
+ new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])
+ params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings
+
+ model.params = params
+ return model
+
+
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+    # Set up logging: we only want one process per machine to log things on the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ # Create sampling rng
+ rng = jax.random.PRNGKey(args.seed)
+ rng, _ = jax.random.split(rng)
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder = resize_token_embeddings(
+ text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng
+ )
+ original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+
+ batch = {"pixel_values": pixel_values, "input_ids": input_ids}
+ batch = {k: v.numpy() for k, v in batch.items()}
+
+ return batch
+
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
+ )
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ optimizer = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
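+    # Build a pytree of labels for optax.multi_transform: the token-embedding table is
+    # labeled "token_embedding" (real optimizer), every other leaf "zero" (no-op updates).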
+ def create_mask(params, label_fn):
+ def _map(params, mask, label_fn):
+ for k in params:
+ if label_fn(k):
+ mask[k] = "token_embedding"
+ else:
+ if isinstance(params[k], dict):
+ mask[k] = {}
+ _map(params[k], mask[k], label_fn)
+ else:
+ mask[k] = "zero"
+
+ mask = {}
+ _map(params, mask, label_fn)
+ return mask
+
+ def zero_grads():
+ # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
+ def init_fn(_):
+ return ()
+
+ def update_fn(updates, state, params=None):
+ return jax.tree_util.tree_map(jnp.zeros_like, updates), ()
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+ # Zero out gradients of layers other than the token embedding layer
+ tx = optax.multi_transform(
+ {"token_embedding": optimizer, "zero": zero_grads()},
+ create_mask(text_encoder.params, lambda s: s == "token_embedding"),
+ )
+
+ state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)
+
+ noise_scheduler = FlaxDDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ noise_scheduler_state = noise_scheduler.create_state()
+
+ # Initialize our training
+ train_rngs = jax.random.split(rng, jax.local_device_count())
+
+ # Define gradient train step fn
+ def train_step(state, vae_params, unet_params, batch, train_rng):
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
+
+ def compute_loss(params):
+ vae_outputs = vae.apply(
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+ encoder_hidden_states = state.apply_fn(
+ batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
+ )[0]
+ # Predict the noise residual and compute loss
+ model_pred = unet.apply(
+ {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(state.params)
+ grad = jax.lax.pmean(grad, "batch")
+ new_state = state.apply_gradients(grads=grad)
+
+ # Keep the token embeddings fixed except the newly added embeddings for the concept,
+ # as we only want to optimize the concept embeddings
+ token_embeds = original_token_embeds.at[placeholder_token_id].set(
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id]
+ )
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+ return new_state, metrics, new_train_rng
+
+ # Create parallel version of the train and eval step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+ vae_params = jax_utils.replicate(vae_params)
+ unet_params = jax_utils.replicate(unet_params)
+
+ # Train!
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+
+ epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+ # train
+ for batch in train_dataloader:
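+            # shard() reshapes each array so the leading axis maps to local devices
+            # for the pmapped train step.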
+ batch = shard(batch)
+ state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(1)
+ global_step += 1
+
+ if global_step >= args.max_train_steps:
+ break
+
+ train_metric = jax_utils.unreplicate(train_metric)
+
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+    # Create the pipeline using the trained modules and save it.
+ if jax.process_index() == 0:
+ scheduler = FlaxPNDMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+ )
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
+ )
+ pipeline = FlaxStableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+
+ pipeline.save_pretrained(
+ args.output_dir,
+ params={
+ "text_encoder": get_params_to_save(state.params),
+ "vae": get_params_to_save(vae_params),
+ "unet": get_params_to_save(unet_params),
+ "safety_checker": safety_checker.params,
+ },
+ )
+
+ # Also save the newly trained embeddings
+ learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+ placeholder_token_id
+ ]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds}
+ jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/onnxruntime/README.md b/diffusers/examples/research_projects/onnxruntime/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a7935189f61b3078b0926661cd1b3ee91fc2295a
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/README.md
@@ -0,0 +1,5 @@
+## Diffusers examples with ONNXRuntime optimizations
+
+**This research project is not actively maintained by the diffusers team. For any questions or comments, please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on GitHub.**
+
+This project provides diffusers examples with ONNX Runtime optimizations for training/fine-tuning unconditional image generation, text-to-image, and textual inversion models. Please see the individual directories for more details on how to run each task with ONNX Runtime.
diff --git a/diffusers/examples/research_projects/onnxruntime/text_to_image/README.md b/diffusers/examples/research_projects/onnxruntime/text_to_image/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f1f134c576b26aeddb637c803b9a70b4c8c1a97a
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/text_to_image/README.md
@@ -0,0 +1,74 @@
+# Stable Diffusion text-to-image fine-tuning
+
+The `train_text_to_image.py` script shows how to fine-tune the Stable Diffusion model on your own dataset.
+
+___Note___:
+
+___This script is experimental. It fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
+
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run:
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+### Naruto example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token:
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+## Use ONNXRuntime to accelerate training
+In order to leverage ONNX Runtime to accelerate training, please use `train_text_to_image.py`.
+
+The command to train a DDPM UNet2DConditionModel on the Naruto dataset with ONNX Runtime:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export dataset_name="lambdalabs/naruto-blip-captions"
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$dataset_name \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-naruto-model"
+```
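+
+Once training finishes, the fine-tuned pipeline is saved to `sd-naruto-model` and can be loaded like any other diffusers checkpoint. A minimal inference sketch (the prompt below is only an illustration):
+
+```py
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("sd-naruto-model", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+image = pipe(prompt="a naruto character riding a bicycle").images[0]
+image.save("naruto_bicycle.png")
+```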
+
+Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on GitHub with any questions.
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/onnxruntime/text_to_image/requirements.txt b/diffusers/examples/research_projects/onnxruntime/text_to_image/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2dbadea4474aac24b501e61a4b05f24168ac85be
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/text_to_image/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+ftfy
+tensorboard
+modelcards
diff --git a/diffusers/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/diffusers/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..126a10b4f9e90e0716e2784f3e3673dbda37a601
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -0,0 +1,954 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+from onnxruntime.training.ortmodule import ORTModule
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.17.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+        Return either a context list that includes one context manager that disables zero.Init, or an empty context list.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+ # frozen models from being partitioned during `zero.Init` which gets called during
+    # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+    # across multiple GPUs and only UNet2DConditionModel will get ZeRO sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # Freeze vae and text_encoder
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
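+    # Wrap the optimizer with ONNX Runtime's FP16_Optimizer to speed up mixed-precision (fp16)
+    # optimizer updates when training through ORTModule.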
+ optimizer = ORT_FP16_Optimizer(optimizer)
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
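+    # Wrap the UNet in ORTModule so its forward and backward passes run through
+    # ONNX Runtime's optimized training graph instead of eager PyTorch.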
+ unet = ORTModule(unet)
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+    # Move text_encoder and vae to the GPU and cast them to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
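+            # Convert the saved optimizer-step count back into dataloader steps so we can
+            # fast-forward past batches that were already consumed before the checkpoint.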
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
+ if args.input_pertubation:
+ new_noise = noise + args.input_pertubation * torch.randn_like(noise)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ if args.input_pertubation:
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+ else:
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/onnxruntime/textual_inversion/README.md b/diffusers/examples/research_projects/onnxruntime/textual_inversion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0f6ec7f51186539379df53e4b1b49c6332789cb1
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/textual_inversion/README.md
@@ -0,0 +1,94 @@
+## Textual Inversion fine-tuning example
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like Stable Diffusion on your own images using just 3-5 examples.
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+
+## Running on Colab
+
+Colab for training
+[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+
+Colab for inference
+[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+
+### Cat toy example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
+
+Let's first download it locally:
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./cat"
+snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes")
+```
+
+This will be our training data.
+Now we can launch training with the command shown in the next section.
+
+## Use ONNXRuntime to accelerate training
+In order to leverage ONNX Runtime to accelerate training, please use the `textual_inversion.py` script in this folder.
+
+The command to train on custom data with ONNX Runtime:
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export DATA_DIR="path-to-dir-containing-images"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="<cat-toy>" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="textual_inversion_cat"
+```
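+
+Once training finishes, the learned embedding is written to `learned_embeds.bin` inside `--output_dir`. Below is a minimal inference sketch, assuming the output directory from the command above and a recent diffusers release that provides the `load_textual_inversion` helper:
+
+```py
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+# Load the learned embedding; the placeholder token you trained with
+# ("<cat-toy>" in the command above) then becomes usable in prompts.
+pipe.load_textual_inversion("textual_inversion_cat/learned_embeds.bin")
+
+image = pipe("A <cat-toy> backpack", num_inference_steps=50).images[0]
+image.save("cat-backpack.png")
+```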
+
+Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on GitHub with any questions.
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/onnxruntime/textual_inversion/requirements.txt b/diffusers/examples/research_projects/onnxruntime/textual_inversion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c1a94eac83e6eb9a3a2dd11672b5d73f794ca3d1
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/textual_inversion/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+modelcards
diff --git a/diffusers/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/diffusers/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e10564fa59efb191ea37fef6c2fa15d9203c7f86
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -0,0 +1,958 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import warnings
+from pathlib import Path
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+from onnxruntime.training.ortmodule import ORTModule
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.17.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = None, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- textual_inversion
+- diffusers-training
+- onnxruntime
+inference: true
+---
+ """
+ model_card = f"""
+# Textual inversion text2image fine-tuning - {repo_id}
+These are textual inversion adaptation weights for {base_model}. You can find some example images below. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=unet,
+ vae=vae,
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+ return images
+
+
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
+ logger.info("Saving embeddings")
+ learned_embeds = (
+ accelerator.unwrap_model(text_encoder)
+ .get_input_embeddings()
+ .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
+ )
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
+ parser.add_argument(
+ "--save_as_full_pipeline",
+ action="store_true",
+ help="Save the complete stable diffusion pipeline.",
+ )
+ parser.add_argument(
+ "--num_vectors",
+ type=int,
+ default=1,
+ help="How many textual inversion vectors shall be used to learn the concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=None,
+ help=(
+ "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # Add the placeholder token in tokenizer
+ placeholder_tokens = [args.placeholder_token]
+
+ if args.num_vectors < 1:
+ raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")
+
+ # add dummy tokens for multi-vector
+ additional_tokens = []
+ for i in range(1, args.num_vectors):
+ additional_tokens.append(f"{args.placeholder_token}_{i}")
+ placeholder_tokens += additional_tokens
+
+ num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
+ if num_added_tokens != args.num_vectors:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ with torch.no_grad():
+ for token_id in placeholder_token_ids:
+ token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+
+ # Freeze vae and unet
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ # Freeze all parameters except for the token embeddings in text encoder
+ text_encoder.text_model.encoder.requires_grad_(False)
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ # Keep unet in train mode if we are using gradient checkpointing to save memory.
+ # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
+ unet.train()
+ text_encoder.gradient_checkpointing_enable()
+ unet.enable_gradient_checkpointing()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ optimizer = ORT_FP16_Optimizer(optimizer)
+
+ # Dataset and DataLoaders creation:
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+ if args.validation_epochs is not None:
+ warnings.warn(
+ f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
+ " Deprecated validation_epochs in favor of `validation_steps`"
+ f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
+ FutureWarning,
+ stacklevel=2,
+ )
+ args.validation_steps = args.validation_epochs * len(train_dataset)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+
+ text_encoder = ORTModule(text_encoder)
+ unet = ORTModule(unet)
+ vae = ORTModule(vae)
+
+ # For mixed precision training we cast the unet and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and unet to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ # keep original embeddings as reference
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+ index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+
+ with torch.no_grad():
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ images = []
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+ save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.push_to_hub and not args.save_as_full_pipeline:
+ logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = args.save_as_full_pipeline
+ if save_full_model:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+ save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/README.md b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c28ecefc9a3002b2f6c6d3d97e53047e82ab2733
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/README.md
@@ -0,0 +1,50 @@
+## Training examples
+
+Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+#### Use ONNXRuntime to accelerate training
+
+In order to leverage ONNX Runtime to accelerate training, please use the `train_unconditional.py` script in this folder.
+
+The command to train a DDPM UNet model on the Oxford Flowers dataset with ONNX Runtime:
+
+```bash
+accelerate launch train_unconditional.py \
+ --dataset_name="huggan/flowers-102-categories" \
+ --resolution=64 --center_crop --random_flip \
+ --output_dir="ddpm-ema-flowers-64" \
+ --use_ema \
+ --train_batch_size=16 \
+ --num_epochs=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision=fp16
+```
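+
+Once training completes, you can sample from the trained model. Below is a minimal sketch, assuming the `ddpm-ema-flowers-64` output directory from the command above and that the script saves a `DDPMPipeline` to `--output_dir` as in the standard unconditional image generation example:
+
+```py
+from diffusers import DDPMPipeline
+
+# Load the trained pipeline from the training output directory.
+pipeline = DDPMPipeline.from_pretrained("ddpm-ema-flowers-64")
+pipeline = pipeline.to("cuda")
+
+# Generate one 64x64 sample; fewer inference steps run faster at some cost in quality.
+image = pipeline(batch_size=1, num_inference_steps=1000).images[0]
+image.save("flower_sample.png")
+```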
+
+Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on GitHub with any questions.
diff --git a/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca21143c42d9c08bf693fcc8cd11fed53acb895f
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt
@@ -0,0 +1,4 @@
+accelerate>=0.16.0
+torchvision
+datasets
+tensorboard
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a00f7cc4a9adf7cde0cbd17b9aa3b2153927cdf
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -0,0 +1,697 @@
+import argparse
+import inspect
+import logging
+import math
+import os
+from pathlib import Path
+
+import accelerate
+import datasets
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+from onnxruntime.training.ortmodule import ORTModule
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.17.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ if not isinstance(arr, torch.Tensor):
+ arr = torch.from_numpy(arr)
+ res = arr[timesteps].float().to(timesteps.device)
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that HF Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--model_config_name_or_path",
+ type=str,
+ default=None,
+ help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="ddpm-model-64",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--overwrite_output_dir", action="store_true")
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=64,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ default=False,
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+ " process."
+ ),
+ )
+ parser.add_argument("--num_epochs", type=int, default=100)
+ parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+ parser.add_argument(
+ "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="cosine",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+ parser.add_argument(
+ "--use_ema",
+ action="store_true",
+ help="Whether to use Exponential Moving Average for the final model weights.",
+ )
+ parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+ parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+ parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+ )
+ parser.add_argument(
+ "--logger",
+ type=str,
+ default="tensorboard",
+ choices=["tensorboard", "wandb"],
+ help=(
+ "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
+ " for experiment tracking and logging of model metrics and model checkpoints"
+ ),
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default="epsilon",
+ choices=["epsilon", "sample"],
+ help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+ )
+ parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+ return args
+
+
+def main(args):
+ if args.logger == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.logger,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.logger == "tensorboard":
+ if not is_tensorboard_available():
+ raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
+
+ elif args.logger == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
+ ema_model.load_state_dict(load_model.state_dict())
+ ema_model.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Initialize the model
+ if args.model_config_name_or_path is None:
+ model = UNet2DModel(
+ sample_size=args.resolution,
+ in_channels=3,
+ out_channels=3,
+ layers_per_block=2,
+ block_out_channels=(128, 128, 256, 256, 512, 512),
+ down_block_types=(
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "AttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types=(
+ "UpBlock2D",
+ "AttnUpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ ),
+ )
+ else:
+ config = UNet2DModel.load_config(args.model_config_name_or_path)
+ model = UNet2DModel.from_config(config)
+
+ # Create EMA for the model.
+ if args.use_ema:
+ ema_model = EMAModel(
+ model.parameters(),
+ decay=args.ema_max_decay,
+ use_ema_warmup=True,
+ inv_gamma=args.ema_inv_gamma,
+ power=args.ema_power,
+ model_cls=UNet2DModel,
+ model_config=model.config,
+ )
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ model.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Initialize the scheduler
+ accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+ if accepts_prediction_type:
+ noise_scheduler = DDPMScheduler(
+ num_train_timesteps=args.ddpm_num_steps,
+ beta_schedule=args.ddpm_beta_schedule,
+ prediction_type=args.prediction_type,
+ )
+ else:
+ noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ optimizer = ORT_FP16_Optimizer(optimizer)
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ split="train",
+ )
+ else:
+ dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets and DataLoaders creation.
+ augmentations = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def transform_images(examples):
+ images = [augmentations(image.convert("RGB")) for image in examples["image"]]
+ return {"input": images}
+
+ logger.info(f"Dataset size: {len(dataset)}")
+
+ dataset.set_transform(transform_images)
+ train_dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # Initialize the learning rate scheduler
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=(len(train_dataloader) * args.num_epochs),
+ )
+
+ # Prepare everything with our `accelerator`.
+ model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ model, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_model.to(accelerator.device)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ run = os.path.split(__file__)[-1].split(".")[0]
+ accelerator.init_trackers(run)
+
+ model = ORTModule(model)
+
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ max_train_steps = args.num_epochs * num_update_steps_per_epoch
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(dataset)}")
+ logger.info(f" Num Epochs = {args.num_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_train_steps}")
+
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
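+ # e.g. (hypothetical numbers) with num_update_steps_per_epoch=100, gradient_accumulation_steps=2
+ # and a checkpoint saved at global_step=250, training resumes in epoch 2 at dataloader step 100 (= 500 % 200)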
+
+ # Train!
+ for epoch in range(first_epoch, args.num_epochs):
+ model.train()
+ progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
+ progress_bar.set_description(f"Epoch {epoch}")
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ clean_images = batch["input"]
+ # Sample noise that we'll add to the images
+ noise = torch.randn(
+ clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
+ ).to(clean_images.device)
+ bsz = clean_images.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
+ ).long()
+
+ # Add noise to the clean images according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
+
+ with accelerator.accumulate(model):
+ # Predict the noise residual
+ model_output = model(noisy_images, timesteps, return_dict=False)[0]
+
+ if args.prediction_type == "epsilon":
+ loss = F.mse_loss(model_output, noise) # this could have different weights!
+ elif args.prediction_type == "sample":
+ alpha_t = _extract_into_tensor(
+ noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
+ )
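+ # alphas_cumprod gives the cumulative product of alphas (alpha_bar_t), so the weight
+ # below equals the signal-to-noise ratio alpha_bar_t / (1 - alpha_bar_t) at each sampled timestep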
+ snr_weights = alpha_t / (1 - alpha_t)
+ loss = snr_weights * F.mse_loss(
+ model_output, clean_images, reduction="none"
+ ) # use SNR weighting from distillation paper
+ loss = loss.mean()
+ else:
+ raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
+
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_model.step(model.parameters())
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+ if args.use_ema:
+ logs["ema_decay"] = ema_model.cur_decay_value
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+ progress_bar.close()
+
+ accelerator.wait_for_everyone()
+
+ # Generate sample images for visual inspection
+ if accelerator.is_main_process:
+ if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
+ unet = accelerator.unwrap_model(model)
+
+ if args.use_ema:
+ ema_model.store(unet.parameters())
+ ema_model.copy_to(unet.parameters())
+
+ pipeline = DDPMPipeline(
+ unet=unet,
+ scheduler=noise_scheduler,
+ )
+
+ generator = torch.Generator(device=pipeline.device).manual_seed(0)
+ # run pipeline in inference (sample random noise and denoise)
+ images = pipeline(
+ generator=generator,
+ batch_size=args.eval_batch_size,
+ num_inference_steps=args.ddpm_num_inference_steps,
+ output_type="np",
+ ).images
+
+ if args.use_ema:
+ ema_model.restore(unet.parameters())
+
+ # denormalize the images and save to tensorboard
+ images_processed = (images * 255).round().astype("uint8")
+
+ if args.logger == "tensorboard":
+ if is_accelerate_version(">=", "0.17.0.dev0"):
+ tracker = accelerator.get_tracker("tensorboard", unwrap=True)
+ else:
+ tracker = accelerator.get_tracker("tensorboard")
+ tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
+ elif args.logger == "wandb":
+ # An upcoming `log_images` helper is being added in https://github.com/huggingface/accelerate/pull/962/files
+ accelerator.get_tracker("wandb").log(
+ {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
+ step=global_step,
+ )
+
+ if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+ # save the model
+ unet = accelerator.unwrap_model(model)
+
+ if args.use_ema:
+ ema_model.store(unet.parameters())
+ ema_model.copy_to(unet.parameters())
+
+ pipeline = DDPMPipeline(
+ unet=unet,
+ scheduler=noise_scheduler,
+ )
+
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.use_ema:
+ ema_model.restore(unet.parameters())
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message=f"Epoch {epoch}",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/promptdiffusion/README.md b/diffusers/examples/research_projects/promptdiffusion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..33ffec312501f90bbe832829f5723c50b1d8899d
--- /dev/null
+++ b/diffusers/examples/research_projects/promptdiffusion/README.md
@@ -0,0 +1,49 @@
+# PromptDiffusion Pipeline
+
+From the project [page](https://zhendong-wang.github.io/prompt-diffusion.github.io/)
+
+"With a prompt consisting of a task-specific example pair of images and text guidance, and a new query image, Prompt Diffusion can comprehend the desired task and generate the corresponding output image on both seen (trained) and unseen (new) task types."
+
+For any usage questions, please refer to the [paper](https://arxiv.org/abs/2305.01115).
+
+Prepare the models by converting them from the original [checkpoint](https://huggingface.co/zhendongw/prompt-diffusion).
+
+To convert the controlnet, use cldm_v15.yaml from the [repository](https://github.com/Zhendong-Wang/Prompt-Diffusion/tree/main/models/):
+
+```bash
+python convert_original_promptdiffusion_to_diffusers.py --checkpoint_path path-to-network-step04999.ckpt --original_config_file path-to-cldm_v15.yaml --dump_path path-to-output-directory
+```
+
+To learn how to convert the fine-tuned Stable Diffusion model, see the [Load different Stable Diffusion formats guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/other-formats).
+
+
+```py
+import torch
+from diffusers import UniPCMultistepScheduler
+from diffusers.utils import load_image
+from promptdiffusioncontrolnet import PromptDiffusionControlNetModel
+from pipeline_prompt_diffusion import PromptDiffusionPipeline
+
+
+from PIL import ImageOps
+
+image_a = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house_line.png?raw=true"))
+
+image_b = load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house.png?raw=true")
+query = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/new_01.png?raw=true"))
+
+# load the Prompt Diffusion ControlNet and the Prompt Diffusion pipeline
+
+controlnet = PromptDiffusionControlNetModel.from_pretrained("iczaw/prompt-diffusion-diffusers", subfolder="controlnet", torch_dtype=torch.float16)
+model_id = "path-to-model"
+pipe = PromptDiffusionPipeline.from_pretrained("iczaw/prompt-diffusion-diffusers", subfolder="base", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16")
+
+# speed up diffusion process with faster scheduler and memory optimization
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+# remove the following line if xformers is not installed
+pipe.enable_xformers_memory_efficient_attention()
+pipe.enable_model_cpu_offload()
+# generate image
+generator = torch.manual_seed(0)
+image = pipe("a tortoise", num_inference_steps=20, generator=generator, image_pair=[image_a, image_b], image=query).images[0]
+```
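+
+The returned image is a regular PIL image; as a minimal follow-up to the example above, you can save it directly (the file name here is arbitrary):
+
+```py
+image.save("prompt_diffusion_output.png")
+```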
diff --git a/diffusers/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py b/diffusers/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..26b56a21e8658b1e68826e807d748d69a696fe7f
--- /dev/null
+++ b/diffusers/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py
@@ -0,0 +1,2118 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Conversion script for stable diffusion checkpoints which _only_ contain a controlnet."""
+
+import argparse
+import re
+from contextlib import nullcontext
+from io import BytesIO
+from typing import Dict, Optional, Union
+
+import requests
+import torch
+import yaml
+from promptdiffusioncontrolnet import PromptDiffusionControlNetModel
+from transformers import (
+ AutoFeatureExtractor,
+ BertTokenizerFast,
+ CLIPImageProcessor,
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.models import (
+ AutoencoderKL,
+ ControlNetModel,
+ PriorTransformer,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UnCLIPScheduler,
+)
+from diffusers.utils import is_accelerate_available, logging
+
+
+if is_accelerate_available():
+ from accelerate import init_empty_weights
+ from accelerate.utils import set_module_tensor_to_device
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("q.weight", "to_q.weight")
+ new_item = new_item.replace("q.bias", "to_q.bias")
+
+ new_item = new_item.replace("k.weight", "to_k.weight")
+ new_item = new_item.replace("k.bias", "to_k.bias")
+
+ new_item = new_item.replace("v.weight", "to_v.weight")
+ new_item = new_item.replace("v.bias", "to_v.bias")
+
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ """
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
+ attention layers, and takes into account additional replacements that may arise.
+
+ Assigns the weights to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ # Splits the attention layers into three variables.
+ if attention_paths_to_split is not None:
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+ for path in paths:
+ new_path = path["new"]
+
+ # These have already been assigned
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+ shape = old_checkpoint[path["old"]].shape
+ if is_attn_weight and len(shape) == 3:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ elif is_attn_weight and len(shape) == 4:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def conv_attn_to_linear(checkpoint):
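+ # 1x1 conv weights have shape (out_channels, in_channels, 1, 1); dropping the trailing spatial
+ # dims turns them into the (out_channels, in_channels) weights expected by a linear layer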
+ keys = list(checkpoint.keys())
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
+ """
+ Creates a diffusers UNet config based on the config of the LDM model.
+ """
+ if controlnet:
+ unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
+ else:
+ if (
+ "unet_config" in original_config["model"]["params"]
+ and original_config["model"]["params"]["unet_config"] is not None
+ ):
+ unet_params = original_config["model"]["params"]["unet_config"]["params"]
+ else:
+ unet_params = original_config["model"]["params"]["network_config"]["params"]
+
+ vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
+
+ block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ if unet_params["transformer_depth"] is not None:
+ transformer_layers_per_block = (
+ unet_params["transformer_depth"]
+ if isinstance(unet_params["transformer_depth"], int)
+ else list(unet_params["transformer_depth"])
+ )
+ else:
+ transformer_layers_per_block = 1
+
+ vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
+
+ head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
+ use_linear_projection = (
+ unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
+ head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
+
+ class_embed_type = None
+ addition_embed_type = None
+ addition_time_embed_dim = None
+ projection_class_embeddings_input_dim = None
+ context_dim = None
+
+ if unet_params["context_dim"] is not None:
+ context_dim = (
+ unet_params["context_dim"]
+ if isinstance(unet_params["context_dim"], int)
+ else unet_params["context_dim"][0]
+ )
+
+ if "num_classes" in unet_params:
+ if unet_params["num_classes"] == "sequential":
+ if context_dim in [2048, 1280]:
+ # SDXL
+ addition_embed_type = "text_time"
+ addition_time_embed_dim = 256
+ else:
+ class_embed_type = "projection"
+ assert "adm_in_channels" in unet_params
+ projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
+
+ config = {
+ "sample_size": image_size // vae_scale_factor,
+ "in_channels": unet_params["in_channels"],
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": unet_params["num_res_blocks"],
+ "cross_attention_dim": context_dim,
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "addition_embed_type": addition_embed_type,
+ "addition_time_embed_dim": addition_time_embed_dim,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "transformer_layers_per_block": transformer_layers_per_block,
+ }
+
+ if "disable_self_attentions" in unet_params:
+ config["only_cross_attention"] = unet_params["disable_self_attentions"]
+
+ if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
+ config["num_class_embeds"] = unet_params["num_classes"]
+
+ if controlnet:
+ config["conditioning_channels"] = unet_params["hint_channels"]
+ else:
+ config["out_channels"] = unet_params["out_channels"]
+ config["up_block_types"] = tuple(up_block_types)
+
+ return config
+
+
+def create_vae_diffusers_config(original_config, image_size: int):
+ """
+ Creates a diffusers VAE config based on the config of the LDM model.
+ """
+ vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
+ _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
+
+ block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+ config = {
+ "sample_size": image_size,
+ "in_channels": vae_params["in_channels"],
+ "out_channels": vae_params["out_ch"],
+ "down_block_types": tuple(down_block_types),
+ "up_block_types": tuple(up_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "latent_channels": vae_params["z_channels"],
+ "layers_per_block": vae_params["num_res_blocks"],
+ }
+ return config
+
+
+def create_diffusers_schedular(original_config):
+ schedular = DDIMScheduler(
+ num_train_timesteps=original_config["model"]["params"]["timesteps"],
+ beta_start=original_config["model"]["params"]["linear_start"],
+ beta_end=original_config["model"]["params"]["linear_end"],
+ beta_schedule="scaled_linear",
+ )
+ return schedular
+
+
+def create_ldm_bert_config(original_config):
+ bert_params = original_config["model"]["params"]["cond_stage_config"]["params"]
+ config = LDMBertConfig(
+ d_model=bert_params.n_embed,
+ encoder_layers=bert_params.n_layer,
+ encoder_ffn_dim=bert_params.n_embed * 4,
+ )
+ return config
+
+
+def convert_ldm_unet_checkpoint(
+ checkpoint,
+ config,
+ path=None,
+ extract_ema=False,
+ controlnet=False,
+ skip_extract_state_dict=False,
+ promptdiffusion=False,
+):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+
+ if skip_extract_state_dict:
+ unet_state_dict = checkpoint
+ else:
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ keys = list(checkpoint.keys())
+
+ if controlnet:
+ unet_key = "control_model."
+ else:
+ unet_key = "model.diffusion_model."
+
+ # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+ logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
+ logger.warning(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ logger.warning(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] is None:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ if config["addition_embed_type"] == "text_time":
+ new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+
+ # Relevant to StableDiffusionUpscalePipeline
+ if "num_class_embeds" in config:
+ if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
+ new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ if not controlnet:
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnets)
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ if controlnet and not promptdiffusion:
+ # conditioning embedding
+
+ orig_index = 0
+
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.weight"
+ )
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.bias"
+ )
+
+ orig_index += 2
+
+ diffusers_index = 0
+
+ while diffusers_index < 6:
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.weight"
+ )
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.bias"
+ )
+ diffusers_index += 1
+ orig_index += 2
+
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.weight"
+ )
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.bias"
+ )
+
+ # down blocks
+ for i in range(num_input_blocks):
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
+
+ # mid block
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
+
+ if promptdiffusion:
+ # conditioning embedding
+
+ orig_index = 0
+
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.weight"
+ )
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.bias"
+ )
+
+ new_checkpoint["controlnet_query_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
+ f"input_cond_block.{orig_index}.weight"
+ )
+ new_checkpoint["controlnet_query_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
+ f"input_cond_block.{orig_index}.bias"
+ )
+ orig_index += 2
+
+ diffusers_index = 0
+
+ while diffusers_index < 6:
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.weight"
+ )
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.bias"
+ )
+ new_checkpoint[f"controlnet_query_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
+ f"input_cond_block.{orig_index}.weight"
+ )
+ new_checkpoint[f"controlnet_query_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
+ f"input_cond_block.{orig_index}.bias"
+ )
+ diffusers_index += 1
+ orig_index += 2
+
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.weight"
+ )
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+ f"input_hint_block.{orig_index}.bias"
+ )
+
+ new_checkpoint["controlnet_query_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+ f"input_cond_block.{orig_index}.weight"
+ )
+ new_checkpoint["controlnet_query_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+ f"input_cond_block.{orig_index}.bias"
+ )
+ # down blocks
+ for i in range(num_input_blocks):
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
+
+ # mid block
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
+
+ return new_checkpoint
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+ # extract state dict for VAE
+ vae_state_dict = {}
+ keys = list(checkpoint.keys())
+ vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
+ for key in keys:
+ if key.startswith(vae_key):
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+ return new_checkpoint
+
+
+def convert_ldm_bert_checkpoint(checkpoint, config):
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
+
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
+
+ def _copy_linear(hf_linear, pt_linear):
+ hf_linear.weight = pt_linear.weight
+ hf_linear.bias = pt_linear.bias
+
+ def _copy_layer(hf_layer, pt_layer):
+ # copy layer norms
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
+
+ # copy attn
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
+
+ # copy MLP
+ pt_mlp = pt_layer[1][1]
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
+
+ def _copy_layers(hf_layers, pt_layers):
+ for i, hf_layer in enumerate(hf_layers):
+ if i != 0:
+ i += i
+ pt_layer = pt_layers[i : i + 2]
+ _copy_layer(hf_layer, pt_layer)
+
+ hf_model = LDMBertModel(config).eval()
+
+ # copy embeds
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
+
+ # copy layer norm
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
+
+ # copy hidden layers
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
+
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
+
+ return hf_model
+
+
+def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
+ if text_encoder is None:
+ config_name = "openai/clip-vit-large-patch14"
+ try:
+ config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
+ except Exception:
+ raise ValueError(
+ f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
+ )
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ text_model = CLIPTextModel(config)
+ else:
+ text_model = text_encoder
+
+ keys = list(checkpoint.keys())
+
+ text_model_dict = {}
+
+ remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]
+
+ for key in keys:
+ for prefix in remove_prefixes:
+ if key.startswith(prefix):
+ text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
+
+ if is_accelerate_available():
+ for param_name, param in text_model_dict.items():
+ set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+ else:
+ if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+ text_model_dict.pop("text_model.embeddings.position_ids", None)
+
+ text_model.load_state_dict(text_model_dict)
+
+ return text_model
+
+
+textenc_conversion_lst = [
+ ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
+ ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
+ ("ln_final.weight", "text_model.final_layer_norm.weight"),
+ ("ln_final.bias", "text_model.final_layer_norm.bias"),
+ ("text_projection", "text_projection.weight"),
+]
+textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
+
+textenc_transformer_conversion_lst = [
+ # (stable-diffusion, HF Diffusers)
+ ("resblocks.", "text_model.encoder.layers."),
+ ("ln_1", "layer_norm1"),
+ ("ln_2", "layer_norm2"),
+ (".c_fc.", ".fc1."),
+ (".c_proj.", ".fc2."),
+ (".attn", ".self_attn"),
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+
+def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False):
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+ model = PaintByExampleImageEncoder(config)
+
+ keys = list(checkpoint.keys())
+
+ text_model_dict = {}
+
+ for key in keys:
+ if key.startswith("cond_stage_model.transformer"):
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+
+ # load clip vision
+ model.model.load_state_dict(text_model_dict)
+
+ # load mapper
+ keys_mapper = {
+ k[len("cond_stage_model.mapper.res") :]: v
+ for k, v in checkpoint.items()
+ if k.startswith("cond_stage_model.mapper")
+ }
+
+ MAPPING = {
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
+ "attn.c_proj": ["attn1.to_out.0"],
+ "ln_1": ["norm1"],
+ "ln_2": ["norm3"],
+ "mlp.c_fc": ["ff.net.0.proj"],
+ "mlp.c_proj": ["ff.net.2"],
+ }
+
+ mapped_weights = {}
+ for key, value in keys_mapper.items():
+ prefix = key[: len("blocks.i")]
+ suffix = key.split(prefix)[-1].split(".")[-1]
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
+ mapped_names = MAPPING[name]
+
+ num_splits = len(mapped_names)
+ for i, mapped_name in enumerate(mapped_names):
+ new_name = ".".join([prefix, mapped_name, suffix])
+ shape = value.shape[0] // num_splits
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
+
+ model.mapper.load_state_dict(mapped_weights)
+
+ # load final layer norm
+ model.final_layer_norm.load_state_dict(
+ {
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
+ }
+ )
+
+ # load final proj
+ model.proj_out.load_state_dict(
+ {
+ "bias": checkpoint["proj_out.bias"],
+ "weight": checkpoint["proj_out.weight"],
+ }
+ )
+
+ # load uncond vector
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
+ return model
+
+
+def convert_open_clip_checkpoint(
+ checkpoint,
+ config_name,
+ prefix="cond_stage_model.model.",
+ has_projection=False,
+ local_files_only=False,
+ **config_kwargs,
+):
+ # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+ # text_model = CLIPTextModelWithProjection.from_pretrained(
+ # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+ # )
+ try:
+ config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
+ except Exception:
+ raise ValueError(
+ f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
+ )
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
+
+ keys = list(checkpoint.keys())
+
+ keys_to_ignore = []
+ if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
+ # make sure to remove all keys > 22
+ keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
+ keys_to_ignore += ["cond_stage_model.model.text_projection"]
+
+ text_model_dict = {}
+
+ if prefix + "text_projection" in checkpoint:
+ d_model = int(checkpoint[prefix + "text_projection"].shape[0])
+ else:
+ d_model = 1024
+
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
+
+ for key in keys:
+ if key in keys_to_ignore:
+ continue
+ if key[len(prefix) :] in textenc_conversion_map:
+ if key.endswith("text_projection"):
+ value = checkpoint[key].T.contiguous()
+ else:
+ value = checkpoint[key]
+
+ text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value
+
+ if key.startswith(prefix + "transformer."):
+ new_key = key[len(prefix + "transformer.") :]
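+ # the original checkpoint packs q, k and v into a single in_proj tensor of shape
+ # (3 * d_model, ...); split it into the separate q/k/v projections used by the HF CLIP model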
+ if new_key.endswith(".in_proj_weight"):
+ new_key = new_key[: -len(".in_proj_weight")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
+ elif new_key.endswith(".in_proj_bias"):
+ new_key = new_key[: -len(".in_proj_bias")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
+ else:
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+
+ text_model_dict[new_key] = checkpoint[key]
+
+ if is_accelerate_available():
+ for param_name, param in text_model_dict.items():
+ set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+ else:
+ if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+ text_model_dict.pop("text_model.embeddings.position_ids", None)
+
+ text_model.load_state_dict(text_model_dict)
+
+ return text_model
+
+
+def stable_unclip_image_encoder(original_config, local_files_only=False):
+ """
+ Returns the image processor and clip image encoder for the img2img unclip pipeline.
+
+ We currently know of two types of stable unclip models which separately use the clip and the openclip image
+ encoders.
+ """
+
+ image_embedder_config = original_config["model"]["params"]["embedder_config"]
+
+ sd_clip_image_embedder_class = image_embedder_config["target"]
+ sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
+
+ if sd_clip_image_embedder_class == "ClipImageEmbedder":
+ clip_model_name = image_embedder_config.params.model
+
+ if clip_model_name == "ViT-L/14":
+ feature_extractor = CLIPImageProcessor()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ "openai/clip-vit-large-patch14", local_files_only=local_files_only
+ )
+ else:
+ raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
+
+ elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
+ feature_extractor = CLIPImageProcessor()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only
+ )
+ else:
+ raise NotImplementedError(
+ f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
+ )
+
+ return feature_extractor, image_encoder
+
+
+def stable_unclip_image_noising_components(
+ original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
+):
+ """
+ Returns the noising components for the img2img and txt2img unclip pipelines.
+
+ Converts the stability noise augmentor into
+ 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
+ 2. a `DDPMScheduler` for holding the noise schedule
+
+ If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
+ """
+ noise_aug_config = original_config["model"]["params"]["noise_aug_config"]
+ noise_aug_class = noise_aug_config["target"]
+ noise_aug_class = noise_aug_class.split(".")[-1]
+
+ if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
+ noise_aug_config = noise_aug_config.params
+ embedding_dim = noise_aug_config.timestep_dim
+ max_noise_level = noise_aug_config.noise_schedule_config.timesteps
+ beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
+
+ image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
+ image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
+
+ if "clip_stats_path" in noise_aug_config:
+ if clip_stats_path is None:
+ raise ValueError("This stable unclip config requires a `clip_stats_path`")
+
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
+ clip_mean = clip_mean[None, :]
+ clip_std = clip_std[None, :]
+
+ clip_stats_state_dict = {
+ "mean": clip_mean,
+ "std": clip_std,
+ }
+
+ image_normalizer.load_state_dict(clip_stats_state_dict)
+ else:
+ raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
+
+ return image_normalizer, image_noising_scheduler
+
+
+def convert_controlnet_checkpoint(
+ checkpoint,
+ original_config,
+ checkpoint_path,
+ image_size,
+ upcast_attention,
+ extract_ema,
+ use_linear_projection=None,
+ cross_attention_dim=None,
+):
+ ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
+ ctrlnet_config["upcast_attention"] = upcast_attention
+
+ ctrlnet_config.pop("sample_size")
+
+ if use_linear_projection is not None:
+ ctrlnet_config["use_linear_projection"] = use_linear_projection
+
+ if cross_attention_dim is not None:
+ ctrlnet_config["cross_attention_dim"] = cross_attention_dim
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ controlnet = ControlNetModel(**ctrlnet_config)
+
+ # Some controlnet ckpt files are distributed independently from the rest of the
+ # model components, e.g. https://huggingface.co/thibaud/controlnet-sd21/
+ if "time_embed.0.weight" in checkpoint:
+ skip_extract_state_dict = True
+ else:
+ skip_extract_state_dict = False
+
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
+ checkpoint,
+ ctrlnet_config,
+ path=checkpoint_path,
+ extract_ema=extract_ema,
+ controlnet=True,
+ skip_extract_state_dict=skip_extract_state_dict,
+ )
+
+ if is_accelerate_available():
+ for param_name, param in converted_ctrl_checkpoint.items():
+ set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
+ else:
+ controlnet.load_state_dict(converted_ctrl_checkpoint)
+
+ return controlnet
+
+
+def convert_promptdiffusion_checkpoint(
+ checkpoint,
+ original_config,
+ checkpoint_path,
+ image_size,
+ upcast_attention,
+ extract_ema,
+ use_linear_projection=None,
+ cross_attention_dim=None,
+):
+ ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
+ ctrlnet_config["upcast_attention"] = upcast_attention
+
+ ctrlnet_config.pop("sample_size")
+
+ if use_linear_projection is not None:
+ ctrlnet_config["use_linear_projection"] = use_linear_projection
+
+ if cross_attention_dim is not None:
+ ctrlnet_config["cross_attention_dim"] = cross_attention_dim
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ controlnet = PromptDiffusionControlNetModel(**ctrlnet_config)
+
+ # Some controlnet ckpt files are distributed independently from the rest of the
+ # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
+ if "time_embed.0.weight" in checkpoint:
+ skip_extract_state_dict = True
+ else:
+ skip_extract_state_dict = False
+
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
+ checkpoint,
+ ctrlnet_config,
+ path=checkpoint_path,
+ extract_ema=extract_ema,
+ promptdiffusion=True,
+ controlnet=True,
+ skip_extract_state_dict=skip_extract_state_dict,
+ )
+
+ if is_accelerate_available():
+ for param_name, param in converted_ctrl_checkpoint.items():
+ set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
+ else:
+ controlnet.load_state_dict(converted_ctrl_checkpoint)
+
+ return controlnet
+
+
+def download_from_original_stable_diffusion_ckpt(
+ checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+    original_config_file: Optional[str] = None,
+    image_size: Optional[int] = None,
+    prediction_type: Optional[str] = None,
+    model_type: Optional[str] = None,
+    extract_ema: bool = False,
+    scheduler_type: str = "pndm",
+    num_in_channels: Optional[int] = None,
+    upcast_attention: Optional[bool] = None,
+    device: Optional[str] = None,
+ from_safetensors: bool = False,
+ stable_unclip: Optional[str] = None,
+ stable_unclip_prior: Optional[str] = None,
+ clip_stats_path: Optional[str] = None,
+ controlnet: Optional[bool] = None,
+ adapter: Optional[bool] = None,
+ load_safety_checker: bool = True,
+ pipeline_class: DiffusionPipeline = None,
+ local_files_only=False,
+ vae_path=None,
+ vae=None,
+ text_encoder=None,
+ text_encoder_2=None,
+ tokenizer=None,
+ tokenizer_2=None,
+ config_files=None,
+) -> DiffusionPipeline:
+ """
+ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
+ config file.
+
+ Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
+ global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
+ recommended that you override the default values and/or supply an `original_config_file` wherever possible.
+
+ Args:
+ checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict.
+ original_config_file (`str`):
+ Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
+ inferred by looking for a key that only exists in SD2.0 models.
+ image_size (`int`, *optional*, defaults to 512):
+ The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
+ Base. Use 768 for Stable Diffusion v2.
+ prediction_type (`str`, *optional*):
+ The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
+ Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
+ num_in_channels (`int`, *optional*, defaults to None):
+ The number of input channels. If `None`, it will be automatically inferred.
+ scheduler_type (`str`, *optional*, defaults to 'pndm'):
+ Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+ "ddim"]`.
+ model_type (`str`, *optional*, defaults to `None`):
+ The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
+ "FrozenCLIPEmbedder", "PaintByExample"]`.
+ extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
+ checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
+ `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
+ inference. Non-EMA weights are usually better to continue fine-tuning.
+ upcast_attention (`bool`, *optional*, defaults to `None`):
+ Whether the attention computation should always be upcasted. This is necessary when running stable
+ diffusion 2.1.
+ device (`str`, *optional*, defaults to `None`):
+ The device to use. Pass `None` to determine automatically.
+        from_safetensors (`bool`, *optional*, defaults to `False`):
+ If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
+ load_safety_checker (`bool`, *optional*, defaults to `True`):
+ Whether to load the safety checker or not. Defaults to `True`.
+ pipeline_class (`str`, *optional*, defaults to `None`):
+ The pipeline class to use. Pass `None` to determine automatically.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ vae (`AutoencoderKL`, *optional*, defaults to `None`):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. If
+            this parameter is `None`, the function will load a new instance of [AutoencoderKL] by itself, if needed.
+ text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+ An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
+ to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+ variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
+ tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+ An instance of
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+ to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
+ needed.
+ config_files (`Dict[str, str]`, *optional*, defaults to `None`):
+ A dictionary mapping from config file names to their contents. If this parameter is `None`, the function
+ will load the config files by itself, if needed. Valid keys are:
+ - `v1`: Config file for Stable Diffusion v1
+ - `v2`: Config file for Stable Diffusion v2
+ - `xl`: Config file for Stable Diffusion XL
+ - `xl_refiner`: Config file for Stable Diffusion XL Refiner
+ return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+ """
+
+ # import pipelines here to avoid circular import error when using from_single_file method
+ from diffusers import (
+ LDMTextToImagePipeline,
+ PaintByExamplePipeline,
+ StableDiffusionControlNetPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ StableDiffusionUpscalePipeline,
+ StableDiffusionXLControlNetInpaintPipeline,
+ StableDiffusionXLImg2ImgPipeline,
+ StableDiffusionXLInpaintPipeline,
+ StableDiffusionXLPipeline,
+ StableUnCLIPImg2ImgPipeline,
+ StableUnCLIPPipeline,
+ )
+
+ if prediction_type == "v-prediction":
+ prediction_type = "v_prediction"
+
+ if isinstance(checkpoint_path_or_dict, str):
+ if from_safetensors:
+ from safetensors.torch import load_file as safe_load
+
+ checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
+ else:
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
+ elif isinstance(checkpoint_path_or_dict, dict):
+ checkpoint = checkpoint_path_or_dict
+
+ # Sometimes models don't have the global_step item
+ if "global_step" in checkpoint:
+ global_step = checkpoint["global_step"]
+ else:
+ logger.debug("global_step key not found in model")
+ global_step = None
+
+ # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
+ # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
+ while "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+
+ if original_config_file is None:
+ key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+ key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
+ key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+ is_upscale = pipeline_class == StableDiffusionUpscalePipeline
+
+ config_url = None
+
+ # model_type = "v1"
+ if config_files is not None and "v1" in config_files:
+ original_config_file = config_files["v1"]
+ else:
+ config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+
+ if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
+ # model_type = "v2"
+ if config_files is not None and "v2" in config_files:
+ original_config_file = config_files["v2"]
+ else:
+ config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+ if global_step == 110000:
+ # v2.1 needs to upcast attention
+ upcast_attention = True
+ elif key_name_sd_xl_base in checkpoint:
+ # only base xl has two text embedders
+ if config_files is not None and "xl" in config_files:
+ original_config_file = config_files["xl"]
+ else:
+ config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+ elif key_name_sd_xl_refiner in checkpoint:
+            # only the xl refiner has a single text embedder
+ if config_files is not None and "xl_refiner" in config_files:
+ original_config_file = config_files["xl_refiner"]
+ else:
+ config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
+
+ if is_upscale:
+ config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
+
+ if config_url is not None:
+ original_config_file = BytesIO(requests.get(config_url).content)
+ else:
+ with open(original_config_file, "r") as f:
+ original_config_file = f.read()
+
+ original_config = yaml.safe_load(original_config_file)
+
+ # Convert the text model.
+ if (
+ model_type is None
+ and "cond_stage_config" in original_config["model"]["params"]
+ and original_config["model"]["params"]["cond_stage_config"] is not None
+ ):
+ model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
+ logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
+ elif model_type is None and original_config["model"]["params"]["network_config"] is not None:
+ if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048:
+ model_type = "SDXL"
+ else:
+ model_type = "SDXL-Refiner"
+ if image_size is None:
+ image_size = 1024
+
+ if pipeline_class is None:
+        # Check if we have an SDXL or SD model and initialize the default pipeline
+ if model_type not in ["SDXL", "SDXL-Refiner"]:
+ pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
+ else:
+ pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
+
+ if num_in_channels is None and pipeline_class in [
+ StableDiffusionInpaintPipeline,
+ StableDiffusionXLInpaintPipeline,
+ StableDiffusionXLControlNetInpaintPipeline,
+ ]:
+ num_in_channels = 9
+ if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
+ num_in_channels = 7
+ elif num_in_channels is None:
+ num_in_channels = 4
+
+ if "unet_config" in original_config["model"]["params"]:
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+
+ if (
+ "parameterization" in original_config["model"]["params"]
+ and original_config["model"]["params"]["parameterization"] == "v"
+ ):
+        if prediction_type is None:
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`
+            # explicitly, as the fallback below relies on a brittle global-step check
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+        if image_size is None:
+            # NOTE: For stable diffusion 2 base one has to pass `image_size=512`
+            # explicitly, as the fallback below relies on a brittle global-step check
+            image_size = 512 if global_step == 875000 else 768
+ else:
+ if prediction_type is None:
+ prediction_type = "epsilon"
+ if image_size is None:
+ image_size = 512
+
+ if controlnet is None and "control_stage_config" in original_config["model"]["params"]:
+ path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
+ controlnet = convert_controlnet_checkpoint(
+ checkpoint, original_config, path, image_size, upcast_attention, extract_ema
+ )
+
+ if "timesteps" in original_config["model"]["params"]:
+ num_train_timesteps = original_config["model"]["params"]["timesteps"]
+ else:
+ num_train_timesteps = 1000
+
+ if model_type in ["SDXL", "SDXL-Refiner"]:
+ scheduler_dict = {
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "beta_end": 0.012,
+ "interpolation_type": "linear",
+ "num_train_timesteps": num_train_timesteps,
+ "prediction_type": "epsilon",
+ "sample_max_value": 1.0,
+ "set_alpha_to_one": False,
+ "skip_prk_steps": True,
+ "steps_offset": 1,
+ "timestep_spacing": "leading",
+ }
+ scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
+ scheduler_type = "euler"
+ else:
+ if "linear_start" in original_config["model"]["params"]:
+ beta_start = original_config["model"]["params"]["linear_start"]
+ else:
+ beta_start = 0.02
+
+ if "linear_end" in original_config["model"]["params"]:
+ beta_end = original_config["model"]["params"]["linear_end"]
+ else:
+ beta_end = 0.085
+ scheduler = DDIMScheduler(
+ beta_end=beta_end,
+ beta_schedule="scaled_linear",
+ beta_start=beta_start,
+ num_train_timesteps=num_train_timesteps,
+ steps_offset=1,
+ clip_sample=False,
+ set_alpha_to_one=False,
+ prediction_type=prediction_type,
+ )
+ # make sure scheduler works correctly with DDIM
+ scheduler.register_to_config(clip_sample=False)
+
+ if scheduler_type == "pndm":
+ config = dict(scheduler.config)
+ config["skip_prk_steps"] = True
+ scheduler = PNDMScheduler.from_config(config)
+ elif scheduler_type == "lms":
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+ elif scheduler_type == "heun":
+ scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
+ elif scheduler_type == "euler":
+ scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
+ elif scheduler_type == "euler-ancestral":
+ scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
+ elif scheduler_type == "dpm":
+ scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+    elif scheduler_type == "ddim":
+        # keep the DDIMScheduler constructed above
+        pass
+ else:
+ raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
+
+ if pipeline_class == StableDiffusionUpscalePipeline:
+ image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
+
+ # Convert the UNet2DConditionModel model.
+ unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+ unet_config["upcast_attention"] = upcast_attention
+
+ path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+ checkpoint, unet_config, path=path, extract_ema=extract_ema
+ )
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ unet = UNet2DConditionModel(**unet_config)
+
+ if is_accelerate_available():
+ if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this.
+ for param_name, param in converted_unet_checkpoint.items():
+ set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+ else:
+ unet.load_state_dict(converted_unet_checkpoint)
+
+ # Convert the VAE model.
+ if vae_path is None and vae is None:
+ vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+
+ if (
+ "model" in original_config
+ and "params" in original_config["model"]
+ and "scale_factor" in original_config["model"]["params"]
+ ):
+ vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
+ else:
+ vae_scaling_factor = 0.18215 # default SD scaling factor
+
+ vae_config["scaling_factor"] = vae_scaling_factor
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ vae = AutoencoderKL(**vae_config)
+
+ if is_accelerate_available():
+ for param_name, param in converted_vae_checkpoint.items():
+ set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+ else:
+ vae.load_state_dict(converted_vae_checkpoint)
+ elif vae is None:
+ vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only)
+
+ if model_type == "FrozenOpenCLIPEmbedder":
+ config_name = "stabilityai/stable-diffusion-2"
+ config_kwargs = {"subfolder": "text_encoder"}
+
+ if text_encoder is None:
+ text_model = convert_open_clip_checkpoint(
+ checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
+ )
+ else:
+ text_model = text_encoder
+
+ try:
+ tokenizer = CLIPTokenizer.from_pretrained(
+ "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
+ )
+ except Exception:
+ raise ValueError(
+ f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'."
+ )
+
+ if stable_unclip is None:
+ if controlnet:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+ if hasattr(pipe, "requires_safety_checker"):
+ pipe.requires_safety_checker = False
+
+ elif pipeline_class == StableDiffusionUpscalePipeline:
+ scheduler = DDIMScheduler.from_pretrained(
+ "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler"
+ )
+ low_res_scheduler = DDPMScheduler.from_pretrained(
+ "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
+ )
+
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ low_res_scheduler=low_res_scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+
+ else:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+ if hasattr(pipe, "requires_safety_checker"):
+ pipe.requires_safety_checker = False
+
+ else:
+ image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
+ original_config, clip_stats_path=clip_stats_path, device=device
+ )
+
+ if stable_unclip == "img2img":
+ feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)
+
+ pipe = StableUnCLIPImg2ImgPipeline(
+ # image encoding components
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ # image noising components
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ # regular denoising components
+ tokenizer=tokenizer,
+ text_encoder=text_model,
+ unet=unet,
+ scheduler=scheduler,
+ # vae
+ vae=vae,
+ )
+ elif stable_unclip == "txt2img":
+ if stable_unclip_prior is None or stable_unclip_prior == "karlo":
+ karlo_model = "kakaobrain/karlo-v1-alpha"
+ prior = PriorTransformer.from_pretrained(
+ karlo_model, subfolder="prior", local_files_only=local_files_only
+ )
+
+ try:
+ prior_tokenizer = CLIPTokenizer.from_pretrained(
+ "openai/clip-vit-large-patch14", local_files_only=local_files_only
+ )
+ except Exception:
+ raise ValueError(
+ f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
+ )
+ prior_text_model = CLIPTextModelWithProjection.from_pretrained(
+ "openai/clip-vit-large-patch14", local_files_only=local_files_only
+ )
+
+ prior_scheduler = UnCLIPScheduler.from_pretrained(
+ karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only
+ )
+ prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
+ else:
+ raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}")
+
+ pipe = StableUnCLIPPipeline(
+ # prior components
+ prior_tokenizer=prior_tokenizer,
+ prior_text_encoder=prior_text_model,
+ prior=prior,
+ prior_scheduler=prior_scheduler,
+ # image noising components
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ # regular denoising components
+ tokenizer=tokenizer,
+ text_encoder=text_model,
+ unet=unet,
+ scheduler=scheduler,
+ # vae
+ vae=vae,
+ )
+ else:
+ raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
+ elif model_type == "PaintByExample":
+ vision_model = convert_paint_by_example_checkpoint(checkpoint)
+ try:
+ tokenizer = CLIPTokenizer.from_pretrained(
+ "openai/clip-vit-large-patch14", local_files_only=local_files_only
+ )
+ except Exception:
+ raise ValueError(
+ f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
+ )
+ try:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
+ )
+ except Exception:
+ raise ValueError(
+ f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'."
+ )
+ pipe = PaintByExamplePipeline(
+ vae=vae,
+ image_encoder=vision_model,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=feature_extractor,
+ )
+ elif model_type == "FrozenCLIPEmbedder":
+ text_model = convert_ldm_clip_checkpoint(
+ checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
+ )
+ try:
+ tokenizer = (
+ CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+ if tokenizer is None
+ else tokenizer
+ )
+ except Exception:
+ raise ValueError(
+ f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
+ )
+
+ if load_safety_checker:
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
+ )
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
+ )
+ else:
+ safety_checker = None
+ feature_extractor = None
+
+ if controlnet:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ else:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ elif model_type in ["SDXL", "SDXL-Refiner"]:
+ is_refiner = model_type == "SDXL-Refiner"
+
+ if (is_refiner is False) and (tokenizer is None):
+ try:
+ tokenizer = CLIPTokenizer.from_pretrained(
+ "openai/clip-vit-large-patch14", local_files_only=local_files_only
+ )
+ except Exception:
+ raise ValueError(
+ f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
+ )
+
+ if (is_refiner is False) and (text_encoder is None):
+ text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
+
+ if tokenizer_2 is None:
+ try:
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
+ "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
+ )
+ except Exception:
+ raise ValueError(
+ f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
+ )
+
+ if text_encoder_2 is None:
+ config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+ config_kwargs = {"projection_dim": 1280}
+ prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model."
+
+ text_encoder_2 = convert_open_clip_checkpoint(
+ checkpoint,
+ config_name,
+ prefix=prefix,
+ has_projection=True,
+ local_files_only=local_files_only,
+ **config_kwargs,
+ )
+
+ if is_accelerate_available(): # SBM Now move model to cpu.
+ for param_name, param in converted_unet_checkpoint.items():
+ set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+
+ if controlnet:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ force_zeros_for_empty_prompt=True,
+ )
+ elif adapter:
+ pipe = pipeline_class(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ adapter=adapter,
+ scheduler=scheduler,
+ force_zeros_for_empty_prompt=True,
+ )
+
+ else:
+ pipeline_kwargs = {
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "unet": unet,
+ "scheduler": scheduler,
+ }
+
+ if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or (
+ pipeline_class == StableDiffusionXLInpaintPipeline
+ ):
+ pipeline_kwargs.update({"requires_aesthetics_score": is_refiner})
+
+ if is_refiner:
+ pipeline_kwargs.update({"force_zeros_for_empty_prompt": False})
+
+ pipe = pipeline_class(**pipeline_kwargs)
+ else:
+ text_config = create_ldm_bert_config(original_config)
+ text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only)
+ pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+ return pipe
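+
+
+# A minimal usage sketch (not called anywhere in this module), with placeholder paths: convert a
+# single-file checkpoint with the function above and save it in diffusers layout. Passing
+# `original_config_file` explicitly avoids the brittle global-step heuristics mentioned in the docstring.
+def _example_download_sd_pipeline():
+    pipe = download_from_original_stable_diffusion_ckpt(
+        "path/to/model.safetensors",  # placeholder checkpoint path
+        original_config_file="path/to/v1-inference.yaml",  # placeholder config path
+        from_safetensors=True,
+        extract_ema=True,
+        load_safety_checker=False,
+    )
+    pipe.save_pretrained("path/to/converted-pipeline", safe_serialization=True)
+    return pipe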
+
+
+def download_controlnet_from_original_ckpt(
+ checkpoint_path: str,
+ original_config_file: str,
+ image_size: int = 512,
+ extract_ema: bool = False,
+ num_in_channels: Optional[int] = None,
+ upcast_attention: Optional[bool] = None,
+    device: Optional[str] = None,
+    from_safetensors: bool = False,
+    use_linear_projection: Optional[bool] = None,
+    cross_attention_dim: Optional[int] = None,
+) -> ControlNetModel:
+ if from_safetensors:
+ from safetensors import safe_open
+
+ checkpoint = {}
+ with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ checkpoint[key] = f.get_tensor(key)
+ else:
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+
+ # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
+ # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
+ while "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+
+    with open(original_config_file, "r") as f:
+        original_config = yaml.safe_load(f)
+
+ if num_in_channels is not None:
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+
+ if "control_stage_config" not in original_config["model"]["params"]:
+ raise ValueError("`control_stage_config` not present in original config")
+
+ controlnet = convert_controlnet_checkpoint(
+ checkpoint,
+ original_config,
+ checkpoint_path,
+ image_size,
+ upcast_attention,
+ extract_ema,
+ use_linear_projection=use_linear_projection,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ return controlnet
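+
+
+# A minimal usage sketch (not called anywhere in this module), with placeholder paths: convert a
+# standalone ControlNet checkpoint and save it as a diffusers `ControlNetModel`.
+def _example_download_controlnet():
+    controlnet = download_controlnet_from_original_ckpt(
+        checkpoint_path="path/to/controlnet.safetensors",  # placeholder
+        original_config_file="path/to/controlnet_config.yaml",  # placeholder
+        from_safetensors=True,
+    )
+    controlnet.save_pretrained("path/to/converted-controlnet", safe_serialization=True)
+    return controlnet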
+
+
+def download_promptdiffusion_from_original_ckpt(
+ checkpoint_path: str,
+ original_config_file: str,
+ image_size: int = 512,
+ extract_ema: bool = False,
+ num_in_channels: Optional[int] = None,
+ upcast_attention: Optional[bool] = None,
+    device: Optional[str] = None,
+    from_safetensors: bool = False,
+    use_linear_projection: Optional[bool] = None,
+    cross_attention_dim: Optional[int] = None,
+) -> PromptDiffusionControlNetModel:
+ if from_safetensors:
+ from safetensors import safe_open
+
+ checkpoint = {}
+ with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ checkpoint[key] = f.get_tensor(key)
+ else:
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+
+ # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
+ # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
+ while "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+
+    with open(original_config_file, "r") as f:
+        original_config = yaml.safe_load(f)
+
+ if num_in_channels is not None:
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+ if "control_stage_config" not in original_config["model"]["params"]:
+ raise ValueError("`control_stage_config` not present in original config")
+
+ controlnet = convert_promptdiffusion_checkpoint(
+ checkpoint,
+ original_config,
+ checkpoint_path,
+ image_size,
+ upcast_attention,
+ extract_ema,
+ use_linear_projection=use_linear_projection,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ return controlnet
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--original_config_file",
+ type=str,
+ required=True,
+ help="The YAML config file corresponding to the original architecture.",
+ )
+ parser.add_argument(
+ "--num_in_channels",
+ default=None,
+ type=int,
+        help="The number of input channels. If `None`, the number of input channels will be automatically inferred.",
+ )
+ parser.add_argument(
+ "--image_size",
+ default=512,
+ type=int,
+ help=(
+ "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
+ " Base. Use 768 for Stable Diffusion v2."
+ ),
+ )
+ parser.add_argument(
+ "--extract_ema",
+ action="store_true",
+ help=(
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_attention",
+ action="store_true",
+ help=(
+ "Whether the attention computation should always be upcasted. This is necessary when running stable"
+ " diffusion 2.1."
+ ),
+ )
+ parser.add_argument(
+ "--from_safetensors",
+ action="store_true",
+ help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
+ )
+ parser.add_argument(
+ "--to_safetensors",
+ action="store_true",
+ help="Whether to store pipeline in safetensors format or not.",
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
+
+ # small workaround to get argparser to parse a boolean input as either true _or_ false
+ def parse_bool(string):
+ if string == "True":
+ return True
+ elif string == "False":
+ return False
+ else:
+ raise ValueError(f"could not parse string as bool {string}")
+
+ parser.add_argument(
+        "--use_linear_projection", help="Override for the `use_linear_projection` setting of the converted ControlNet.", required=False, type=parse_bool
+ )
+
+    parser.add_argument("--cross_attention_dim", help="Override for the `cross_attention_dim` setting of the converted ControlNet.", required=False, type=int)
+
+ args = parser.parse_args()
+
+ controlnet = download_promptdiffusion_from_original_ckpt(
+ checkpoint_path=args.checkpoint_path,
+ original_config_file=args.original_config_file,
+ image_size=args.image_size,
+ extract_ema=args.extract_ema,
+ num_in_channels=args.num_in_channels,
+ upcast_attention=args.upcast_attention,
+ from_safetensors=args.from_safetensors,
+ device=args.device,
+ use_linear_projection=args.use_linear_projection,
+ cross_attention_dim=args.cross_attention_dim,
+ )
+
+ controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
diff --git a/diffusers/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/diffusers/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb4260d4653f98ad65ec6e30b17569c9603db4bf
--- /dev/null
+++ b/diffusers/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -0,0 +1,1328 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Based on [In-Context Learning Unlocked for Diffusion Models](https://arxiv.org/abs/2305.01115)
+# Authors: Zhendong Wang, Yifan Jiang, Yadong Lu, Yelong Shen, Pengcheng He, Weizhu Chen, Zhangyang Wang, Mingyuan Zhou
+# Project Page: https://zhendong-wang.github.io/prompt-diffusion.github.io/
+# Code: https://github.com/Zhendong-Wang/Prompt-Diffusion
+#
+# Adapted to Diffusers by [iczaw](https://github.com/iczaw).
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from promptdiffusioncontrolnet import PromptDiffusionControlNetModel
+ >>> from diffusers.utils import load_image
+ >>> import torch
+
+ >>> from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ >>> from diffusers import UniPCMultistepScheduler
+ >>> from PIL import ImageOps
+
+        >>> # download the conditioning image of the in-context example pair
+ >>> image_a = ImageOps.invert(load_image(
+ ... "https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house_line.png?raw=true"
+ ... ))
+
+        >>> # download the corresponding output image of the example pair
+ >>> image_b = load_image(
+ ... "https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house.png?raw=true"
+ ... )
+
+        >>> # download the query conditioning image
+ >>> query = ImageOps.invert(load_image(
+ ... "https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/new_01.png?raw=true"
+ ... ))
+
+        >>> # load the prompt diffusion controlnet and the prompt diffusion pipeline (placeholder model paths)
+        >>> model_id = "path-to-base-stable-diffusion-model"
+        >>> controlnet = PromptDiffusionControlNetModel.from_pretrained("path-to-converted-promptdiffusion-controlnet", torch_dtype=torch.float16)
+        >>> pipe = DiffusionPipeline.from_pretrained(model_id, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", custom_pipeline="pipeline_prompt_diffusion")
+
+ >>> # speed up diffusion process with faster scheduler and memory optimization
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> # remove following line if xformers is not installed
+ >>> pipe.enable_xformers_memory_efficient_attention()
+
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> generator = torch.manual_seed(0)
+ >>> image = pipe(
+ ... "a tortoise", num_inference_steps=20, generator=generator, image_pair=[image_a,image_b], image=query
+ ... ).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
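+
+
+# A minimal sketch (not used by the pipeline itself): `retrieve_timesteps` with a plain scheduler.
+# `DDPMScheduler` is only an assumed stand-in here; passing `timesteps=[...]` instead of
+# `num_inference_steps` additionally requires a scheduler whose `set_timesteps` accepts custom
+# timesteps, which the helper checks for.
+def _example_retrieve_timesteps():
+    from diffusers import DDPMScheduler
+
+    scheduler = DDPMScheduler()  # default config: 1000 training timesteps
+    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=20, device="cpu")
+    return timesteps, num_inference_steps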
+
+
+class PromptDiffusionPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
+
+ This pipeline also adds experimental support for [Prompt Diffusion](https://arxiv.org/abs/2305.01115).
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
+ additional conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to
+        allow processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+                # Access the `hidden_states` first, which contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
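+
+    # A minimal sketch (not part of the public pipeline API), assuming an already instantiated
+    # pipeline `pipe`: pre-compute embeddings with `encode_prompt` and feed them back through
+    # `prompt_embeds`/`negative_prompt_embeds` in a later `pipe(...)` call.
+    @staticmethod
+    def _example_encode_prompt(pipe, device="cuda"):
+        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
+            "a tortoise",
+            device=device,
+            num_images_per_prompt=1,
+            do_classifier_free_guidance=True,
+            negative_prompt="low quality",
+        )
+        return prompt_embeds, negative_prompt_embeds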
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ image_pair,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in image):
+ raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `image_pair`
+ if len(image_pair) == 2:
+ for image in image_pair:
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ else:
+ raise ValueError(
+ f"You have passed a list of images of length {len(image_pair)}."
+ f"Make sure the list size equals to two."
+ )
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Guidance scale values at which to generate embedding vectors.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ image_pair: List[PipelineImageInput] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
+ input to a single ControlNet.
+ image_pair (`List[PIL.Image.Image]`):
+ A pair of task-specific example images that is passed to the ControlNet as its example condition
+ (`controlnet_cond`).
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ *Deprecated in favor of `callback_on_step_end`.* A function that is called every `callback_steps` steps
+ during inference with the arguments `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ *Deprecated in favor of `callback_on_step_end`.* The frequency at which the `callback` function is
+ called. If not specified, the callback is called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ image_pair,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # 3.1 Prepare image pair
+
+ if isinstance(controlnet, ControlNetModel):
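+ # Each example image is preprocessed separately and the results are concatenated along the
+ # channel dimension (dim=1) to form the ControlNet example condition.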
+ image_pair = torch.cat(
+ [
+ self.prepare_image(
+ image=im,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ for im in image_pair
+ ],
+ 1,
+ )
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = image.shape[-2:]
+ elif isinstance(controlnet, MultiControlNetModel):
+ images = []
+
+ for image_ in image:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+
+ image = images
+ height, width = image[0].shape[-2:]
+ else:
+ assert False
+
+ # 5. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.2 Create tensor stating which controlnets to keep
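+ # For each timestep i, keeps[j] is 1.0 while the step fraction lies inside
+ # [control_guidance_start[j], control_guidance_end[j]] and 0.0 otherwise, so each
+ # ControlNet can be active for only part of the denoising schedule.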
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ is_unet_compiled = is_compiled_module(self.unet)
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Relevant thread:
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ torch._inductor.cudagraph_mark_step_begin()
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
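+ # Combine the user-provided conditioning scale with the keep mask for this timestep,
+ # so the ControlNet contribution is zeroed outside its guidance window.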
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
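+ # PromptDiffusion wiring: the query image is passed as `controlnet_query_cond` and the
+ # example pair (concatenated along channels) as `controlnet_cond`.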
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_query_cond=image,
+ controlnet_cond=image_pair,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py b/diffusers/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..46cabd863dfaca444ac9b819943ef0843ef505fd
--- /dev/null
+++ b/diffusers/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
@@ -0,0 +1,385 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.controlnet import (
+ ControlNetConditioningEmbedding,
+ ControlNetModel,
+ ControlNetOutput,
+)
+from diffusers.utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class PromptDiffusionControlNetModel(ControlNetModel):
+ """
+ A PromptDiffusionControlNet model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter.
+ addition_embed_type_num_heads (`int`, defaults to 64):
+ The number of heads to use for the `TextTimeEmbedding` layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__(
+ in_channels,
+ conditioning_channels,
+ flip_sin_to_cos,
+ freq_shift,
+ down_block_types,
+ mid_block_type,
+ only_cross_attention,
+ block_out_channels,
+ layers_per_block,
+ downsample_padding,
+ mid_block_scale_factor,
+ act_fn,
+ norm_num_groups,
+ norm_eps,
+ cross_attention_dim,
+ transformer_layers_per_block,
+ encoder_hid_dim,
+ encoder_hid_dim_type,
+ attention_head_dim,
+ num_attention_heads,
+ use_linear_projection,
+ class_embed_type,
+ addition_embed_type,
+ addition_time_embed_dim,
+ num_class_embeds,
+ upcast_attention,
+ resnet_time_scale_shift,
+ projection_class_embeddings_input_dim,
+ controlnet_conditioning_channel_order,
+ conditioning_embedding_out_channels,
+ global_pool_conditions,
+ addition_embed_type_num_heads,
+ )
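+ # PromptDiffusion addition: a second conditioning embedding network that encodes the
+ # query image, mirroring the `controlnet_cond_embedding` used for the example condition.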
+ self.controlnet_query_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=3,
+ )
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ controlnet_query_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
+ """
+ The [`~PromptDiffusionControlNetModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ controlnet_query_cond (`torch.Tensor`):
+ The query-image conditional input tensor.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type is not None:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ controlnet_query_cond = self.controlnet_query_cond_embedding(controlnet_query_cond)
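+ # PromptDiffusion: in addition to the standard example condition, the query-image
+ # condition is embedded and added to the latent sample.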
+ sample = sample + controlnet_cond + controlnet_query_cond
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # 5. Control net blocks
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
diff --git a/diffusers/examples/research_projects/pytorch_xla/README.md b/diffusers/examples/research_projects/pytorch_xla/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a6901d5ada9d6d3d1d7ed95bd3b74e6ac2ba5822
--- /dev/null
+++ b/diffusers/examples/research_projects/pytorch_xla/README.md
@@ -0,0 +1,167 @@
+# Stable Diffusion text-to-image fine-tuning using PyTorch/XLA
+
+The `train_text_to_image_xla.py` script shows how to fine-tune a Stable Diffusion model on TPU devices using PyTorch/XLA.
+
+It has been tested on TPU v4 and v5p. The training code has also been tested in multi-host setups.
+
+This script implements Distributed Data Parallel training using the GSPMD feature of the XLA compiler,
+sharding the input batches over the TPU devices.
+
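+The sketch below illustrates the general GSPMD input-sharding pattern this relies on. It is a minimal,
+hypothetical example: the toy dataset and loader are stand-ins, and the actual implementation in
+`train_text_to_image_xla.py` may differ in details.
+
+```python
+import numpy as np
+import torch
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.spmd as xs
+import torch_xla.runtime as xr
+from torch.utils.data import DataLoader, TensorDataset
+
+xr.use_spmd()  # enable GSPMD mode
+
+# Build a 1-D mesh over all TPU devices; the batch ("data") axis is sharded across it.
+num_devices = xr.global_runtime_device_count()
+mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("data",))
+
+# Toy stand-in for the image dataset used by the training script.
+dataset = TensorDataset(torch.randn(64, 3, 512, 512))
+train_dataloader = DataLoader(dataset, batch_size=32)
+
+# Wrap the DataLoader so each NCHW batch is sharded along dim 0 ("data")
+# while it is transferred to the devices.
+device = xm.xla_device()
+loader = pl.MpDeviceLoader(
+    train_dataloader,
+    device,
+    input_sharding=xs.ShardingSpec(mesh, ("data", None, None, None)),
+)
+
+for (pixel_values,) in loader:
+    # pixel_values now lives on the XLA device with its batch dimension sharded.
+    break
+```
+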
+As of 9-11-2024, these are some expected step times.
+
+| accelerator | global batch size | step time (seconds) |
+| ----------- | ----------------- | --------- |
+| v5p-128 | 1024 | 0.245 |
+| v5p-256 | 2048 | 0.234 |
+| v5p-512 | 4096 | 0.2498 |
+
+## Create TPU
+
+To create a TPU on Google Cloud, first set these environment variables:
+
+```bash
+export TPU_NAME=
+export PROJECT_ID=
+export ZONE=
+export ACCELERATOR_TYPE=
+export RUNTIME_VERSION=
+```
+
+Then run the TPU creation command:
+```bash
+gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID}
+--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION}
+--reserved
+```
+
+You can also reserve TPUs in other ways, for example through GKE or queued resources.
+
+## Set up the TPU environment
+
+Install PyTorch and PyTorch/XLA nightly versions:
+```bash
+gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
+--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
+--command='
+pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
+pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+'
+```
+
+Verify that PyTorch and PyTorch/XLA were installed correctly:
+
+```bash
+gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
+--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
+--command='python3 -c "import torch; import torch_xla;"'
+```
+
+Install dependencies:
+```bash
+gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
+--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
+--command='
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers
+git checkout main
+cd examples/research_projects/pytorch_xla
+pip3 install -r requirements.txt
+pip3 install pillow --upgrade
+cd ../../..
+pip3 install .'
+```
+
+## Run the training job
+
+### Authenticate
+
+Run the following command to authenticate with your Hugging Face token:
+
+```bash
+huggingface-cli login
+```
+
+This script only trains the UNet part of the network; the VAE and text encoder
+are kept frozen.
+
+```bash
+gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
+--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
+--command='
+export XLA_DISABLE_FUNCTIONALIZATION=1
+export PROFILE_DIR=/tmp/
+export CACHE_DIR=/tmp/
+export DATASET_NAME=lambdalabs/naruto-blip-captions
+export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
+export TRAIN_STEPS=50
+export OUTPUT_DIR=/tmp/trained-model/
+python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'
+
+```
+
+### Environment variables explained
+
+* `XLA_DISABLE_FUNCTIONALIZATION`: Set to `1` to optimize the performance of the AdamW optimizer.
+* `PROFILE_DIR`: Directory where the profiling results are written.
+* `CACHE_DIR`: Directory used to store XLA-compiled graphs for persistent caching.
+* `DATASET_NAME`: Dataset used to train the model.
+* `PER_HOST_BATCH_SIZE`: Batch size to load per CPU host. For example, for a v5p-16 with 2 CPU hosts, the global batch size will be 2 × `PER_HOST_BATCH_SIZE`. The input batch is sharded along the batch axis.
+* `TRAIN_STEPS`: Total number of training steps to run.
+* `OUTPUT_DIR`: Directory to store the fine-tuned model.
+
+## Run inference using the output model
+
+To run inference with the trained model, simply load the pipeline and pass it input prompts. The first call
+compiles the graph and therefore takes longer; subsequent calls run much faster.
+
+```bash
+export CACHE_DIR=/tmp/
+```
+
+```python
+import torch
+import os
+import sys
+import numpy as np
+
+import torch_xla.core.xla_model as xm
+from time import time
+from diffusers import StableDiffusionPipeline
+import torch_xla.runtime as xr
+
+CACHE_DIR = os.environ.get("CACHE_DIR", None)
+if CACHE_DIR:
+ xr.initialize_cache(CACHE_DIR, readonly=False)
+
+def main():
+ device = xm.xla_device()
+ model_path = "jffacevedo/pxla_trained_model"
+ pipe = StableDiffusionPipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16
+ )
+ pipe.to(device)
+ prompt = ["A naruto with green eyes and red legs."]
+ start = time()
+ print("compiling...")
+ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+ print(f"compile time: {time() - start}")
+ print("generate...")
+ start = time()
+ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+ print(f"generation time (after compile) : {time() - start}")
+ image.save("naruto.png")
+
+if __name__ == '__main__':
+ main()
+```
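+
+Because compiled graphs are stored under `CACHE_DIR` (initialized via `xr.initialize_cache` above), re-running
+the script with the same cache directory should let XLA reuse the compiled graphs instead of recompiling from
+scratch.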
+
+Expected Results:
+
+```bash
+compiling...
+100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [10:03<00:00, 20.10s/it]
+compile time: 720.656970500946
+generate...
+100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 17.65it/s]
+generation time (after compile) : 1.8461642265319824
+```
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/pytorch_xla/requirements.txt b/diffusers/examples/research_projects/pytorch_xla/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c3ffa42f0edc9862345ab82b69ffb5cdfaeac45e
--- /dev/null
+++ b/diffusers/examples/research_projects/pytorch_xla/requirements.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets>=2.19.1
+ftfy
+tensorboard
+Jinja2
+peft==0.7.0
diff --git a/diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d9d8c540f11e500e14c98575d75ab5c20d3327b
--- /dev/null
+++ b/diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
@@ -0,0 +1,669 @@
+import argparse
+import os
+import random
+import time
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.profiler as xp
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.spmd as xs
+import torch_xla.runtime as xr
+from huggingface_hub import create_repo, upload_folder
+from torchvision import transforms
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.training_utils import compute_snr
+from diffusers.utils import is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+
+
+if is_wandb_available():
+ pass
+
+PROFILE_DIR = os.environ.get("PROFILE_DIR", None)
+CACHE_DIR = os.environ.get("CACHE_DIR", None)
+if CACHE_DIR:
+ xr.initialize_cache(CACHE_DIR, readonly=False)
+xr.use_spmd()
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+PORT = 9012
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ repo_folder: str = None,
+):
+ model_description = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. \n
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+import torch
+import os
+import sys
+import numpy as np
+
+import torch_xla.core.xla_model as xm
+from time import time
+from typing import Tuple
+from diffusers import StableDiffusionPipeline
+
+def main():
+    device = xm.xla_device()
+    model_path = "{repo_id}"  # this model's repository on the Hub
+ pipe = StableDiffusionPipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16
+ )
+ pipe.to(device)
+ prompt = ["A naruto with green eyes and red legs."]
+ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+ image.save("naruto.png")
+
+if __name__ == '__main__':
+ main()
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Steps: {args.max_train_steps}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=args.pretrained_model_name_or_path,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+class TrainSD:
+ def __init__(
+ self,
+ vae,
+ weight_dtype,
+ device,
+ noise_scheduler,
+ unet,
+ optimizer,
+ text_encoder,
+ dataloader,
+ args,
+ ):
+ self.vae = vae
+ self.weight_dtype = weight_dtype
+ self.device = device
+ self.noise_scheduler = noise_scheduler
+ self.unet = unet
+ self.optimizer = optimizer
+ self.text_encoder = text_encoder
+ self.args = args
+ self.mesh = xs.get_global_mesh()
+ self.dataloader = iter(dataloader)
+ self.global_step = 0
+
+ def run_optimizer(self):
+ self.optimizer.step()
+
+ def start_training(self):
+ times = []
+ last_time = time.time()
+ step = 0
+ while True:
+ if self.global_step >= self.args.max_train_steps:
+ xm.mark_step()
+ break
+ if step == 4 and PROFILE_DIR is not None:
+ xm.wait_device_ops()
+                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=self.args.profile_duration)
+ try:
+ batch = next(self.dataloader)
+ except Exception as e:
+ print(e)
+ break
+ loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
+ step_time = time.time() - last_time
+ if step >= 10:
+ times.append(step_time)
+ print(f"step: {step}, step_time: {step_time}")
+ if step % 5 == 0:
+ print(f"step: {step}, loss: {loss}")
+ last_time = time.time()
+ self.global_step += 1
+ step += 1
+ # print(f"Average step time: {sum(times)/len(times)}")
+ xm.wait_device_ops()
+
+ def step_fn(
+ self,
+ pixel_values,
+ input_ids,
+ ):
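+        # One training step: encode images to latents, add noise, have the UNet predict it,
+        # and backpropagate the (optionally SNR-weighted) MSE loss.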
+ with xp.Trace("model.forward"):
+ self.optimizer.zero_grad()
+ latents = self.vae.encode(pixel_values).latent_dist.sample()
+ latents = latents * self.vae.config.scaling_factor
+ noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
+ bsz = latents.shape[0]
+ timesteps = torch.randint(
+ 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ )
+ timesteps = timesteps.long()
+
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+ encoder_hidden_states = self.text_encoder(input_ids, return_dict=False)[0]
+ if self.args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
+
+ if self.noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
+ target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+ model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+ with xp.Trace("model.backward"):
+ if self.args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(self.noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if self.noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+ loss.backward()
+ with xp.Trace("optimizer_step"):
+ self.run_optimizer()
+ return loss
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+ )
+ parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--loader_prefetch_size",
+ type=int,
+ default=1,
+ help=("Number of subprocesses to use for data loading to cpu."),
+ )
+ parser.add_argument(
+ "--device_prefetch_size",
+ type=int,
+ default=1,
+ help=("Number of subprocesses to use for data loading to tpu from cpu. "),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+
+ args = parser.parse_args()
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def setup_optimizer(unet, args):
+ optimizer_cls = torch.optim.AdamW
+ return optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ foreach=True,
+ )
+
+
+def load_dataset(args):
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = datasets.load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = datasets.load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ return dataset
+
+
+def get_column_names(dataset, args):
+ column_names = dataset["train"].column_names
+
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ return image_column, caption_column
+
+
+def main(args):
+
+ _ = xp.start_server(PORT)
+
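+    # Build a 1D SPMD mesh over all TPU devices; input batches are sharded along the "x" (batch) axis.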
+ num_devices = xr.global_runtime_device_count()
+ device_ids = np.arange(num_devices)
+ mesh_shape = (num_devices, 1)
+ mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
+ xs.set_global_mesh(mesh)
+
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.non_ema_revision,
+ )
+
+ if xm.is_master_ordinal() and args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+
+ from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
+
+ unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.train()
+
+ # For mixed precision training we cast all non-trainable weights (vae,
+ # non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full
+ # precision is not required.
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ device = xm.xla_device()
+ print("device: ", device)
+ print("weight_dtype: ", weight_dtype)
+
+ text_encoder = text_encoder.to(device, dtype=weight_dtype)
+ vae = vae.to(device, dtype=weight_dtype)
+ unet = unet.to(device, dtype=weight_dtype)
+ optimizer = setup_optimizer(unet, args)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.train()
+
+ dataset = load_dataset(args)
+ image_column, caption_column = get_column_names(dataset, args)
+
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions,
+ max_length=tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ )
+ return inputs.input_ids
+
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ (transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)),
+ (transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ train_dataset = dataset["train"]
+ train_dataset.set_format("torch")
+ train_dataset.set_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(weight_dtype)
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
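+    # Seed a per-host generator and use an effectively infinite RandomSampler so each host draws an independent stream of samples.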
+ g = torch.Generator()
+ g.manual_seed(xr.host_index())
+ sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10), generator=g)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ sampler=sampler,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers,
+ batch_size=args.train_batch_size,
+ )
+
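+    # Wrap the CPU dataloader so batches are prefetched to the TPU and sharded across devices along the batch dimension.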
+ train_dataloader = pl.MpDeviceLoader(
+ train_dataloader,
+ device,
+ input_sharding={
+ "pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
+ "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
+ },
+ loader_prefetch_size=args.loader_prefetch_size,
+ device_prefetch_size=args.device_prefetch_size,
+ )
+
+ if xm.is_master_ordinal():
+ print("***** Running training *****")
+ print(f"Instantaneous batch size per device = {args.train_batch_size}")
+ print(
+ f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
+ )
+ print(f" Total optimization steps = {args.max_train_steps}")
+
+ trainer = TrainSD(
+ vae=vae,
+ weight_dtype=weight_dtype,
+ device=device,
+ noise_scheduler=noise_scheduler,
+ unet=unet,
+ optimizer=optimizer,
+ text_encoder=text_encoder,
+ dataloader=train_dataloader,
+ args=args,
+ )
+
+ trainer.start_training()
+ unet = trainer.unet.to("cpu")
+ vae = trainer.vae.to("cpu")
+ text_encoder = trainer.text_encoder.to("cpu")
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if xm.is_master_ordinal() and args.push_to_hub:
+ save_model_card(args, repo_id, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/rdm/README.md b/diffusers/examples/research_projects/rdm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cfd755e9c9c3d030953fdc0c40834c0d5f51c04d
--- /dev/null
+++ b/diffusers/examples/research_projects/rdm/README.md
@@ -0,0 +1,5 @@
+## Retrieval Augmented Diffusion Models (RDM)
+
+**This research project is not actively maintained by the diffusers team. For any questions or comments, please contact Isamu Isozaki (isamu-isozaki) on GitHub.**
+
+The aim of this project is to provide retrieval augmented diffusion models to diffusers!
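+
+Below is a minimal usage sketch. It assumes `pipe` is an already-assembled `RDMPipeline` (defined in `pipeline_rdm.py`); the default `IndexConfig` values (example dataset and CLIP checkpoint) are placeholders you will likely want to swap out:
+
+```python
+from retriever import IndexConfig, Retriever
+
+# Build a FAISS-backed retriever over a CLIP-embedded image dataset.
+config = IndexConfig()
+retriever = Retriever(config)
+
+# `pipe` is assumed to be an RDMPipeline instance assembled from its submodules.
+# The retriever supplies nearest-neighbour images that are CLIP-encoded and
+# concatenated to the prompt embedding before denoising.
+pipe.retriever = retriever
+image = pipe("a photo of a cute dog", knn=10).images[0]
+image.save("rdm_sample.png")
+```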
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/rdm/pipeline_rdm.py b/diffusers/examples/research_projects/rdm/pipeline_rdm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8093a3f217dd4e5b9e0e80da40c199fd461c8fa
--- /dev/null
+++ b/diffusers/examples/research_projects/rdm/pipeline_rdm.py
@@ -0,0 +1,344 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from PIL import Image
+from retriever import Retriever, normalize_images, preprocess_images
+from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ ImagePipelineOutput,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
+ r"""
+ Pipeline for text-to-image generation using Retrieval Augmented Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ clip ([`CLIPModel`]):
+ Frozen CLIP model. Retrieval Augmented Diffusion uses the CLIP model, specifically the
+ [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ clip: CLIPModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ feature_extractor: CLIPImageProcessor,
+ retriever: Optional[Retriever] = None,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ clip=clip,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+        # Copied from the standard Stable Diffusion pipeline: compute the VAE scale factor and set up the image processor.
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.retriever = retriever
+
+ def _encode_prompt(self, prompt):
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ prompt_embeds = self.clip.get_text_features(text_input_ids.to(self.device))
+ prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
+ prompt_embeds = prompt_embeds[:, None, :]
+ return prompt_embeds
+
+ def _encode_image(self, retrieved_images, batch_size):
+ if len(retrieved_images[0]) == 0:
+ return None
+ for i in range(len(retrieved_images)):
+ retrieved_images[i] = normalize_images(retrieved_images[i])
+ retrieved_images[i] = preprocess_images(retrieved_images[i], self.feature_extractor).to(
+ self.clip.device, dtype=self.clip.dtype
+ )
+ _, c, h, w = retrieved_images[0].shape
+
+ retrieved_images = torch.reshape(torch.cat(retrieved_images, dim=0), (-1, c, h, w))
+ image_embeddings = self.clip.get_image_features(retrieved_images)
+ image_embeddings = image_embeddings / torch.linalg.norm(image_embeddings, dim=-1, keepdim=True)
+ _, d = image_embeddings.shape
+ image_embeddings = torch.reshape(image_embeddings, (batch_size, -1, d))
+ return image_embeddings
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def retrieve_images(self, retrieved_images, prompt_embeds, knn=10):
+ if self.retriever is not None:
+ additional_images = self.retriever.retrieve_imgs_batch(prompt_embeds[:, 0].cpu(), knn).total_examples
+ for i in range(len(retrieved_images)):
+ retrieved_images[i] += additional_images[i][self.retriever.config.image_column]
+ return retrieved_images
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ retrieved_images: Optional[List[Image.Image]] = None,
+ height: int = 768,
+ width: int = 768,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ knn: Optional[int] = 10,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+            height (`int`, *optional*, defaults to 768):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 768):
+                The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if retrieved_images is not None:
+ retrieved_images = [retrieved_images for _ in range(batch_size)]
+ else:
+ retrieved_images = [[] for _ in range(batch_size)]
+ device = self._execution_device
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if prompt_embeds is None:
+ prompt_embeds = self._encode_prompt(prompt)
+ retrieved_images = self.retrieve_images(retrieved_images, prompt_embeds, knn=knn)
+ image_embeddings = self._encode_image(retrieved_images, batch_size)
+ if image_embeddings is not None:
+ prompt_embeds = torch.cat([prompt_embeds, image_embeddings], dim=1)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_embeddings = torch.zeros_like(prompt_embeds).to(prompt_embeds.device)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([uncond_embeddings, prompt_embeds])
+ # get the initial random noise unless the user supplied it
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(
+ image, output_type=output_type, do_denormalize=[True] * image.shape[0]
+ )
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/examples/research_projects/rdm/retriever.py b/diffusers/examples/research_projects/rdm/retriever.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ae4989bd8bb26c47622ae33c950241c8b49c10c
--- /dev/null
+++ b/diffusers/examples/research_projects/rdm/retriever.py
@@ -0,0 +1,248 @@
+import os
+from typing import List
+
+import faiss
+import numpy as np
+import torch
+from datasets import Dataset, load_dataset
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPModel, PretrainedConfig
+
+from diffusers import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def normalize_images(images: List[Image.Image]):
+ images = [np.array(image) for image in images]
+ images = [image / 127.5 - 1 for image in images]
+ return images
+
+
+def preprocess_images(images: List[np.ndarray], feature_extractor: CLIPImageProcessor) -> torch.Tensor:
+ """
+ Preprocesses a list of images into a batch of tensors.
+
+ Args:
+        images (:obj:`List[np.ndarray]`):
+            A list of normalized image arrays to preprocess.
+
+ Returns:
+ :obj:`torch.Tensor`: A batch of tensors.
+ """
+ images = [np.array(image) for image in images]
+ images = [(image + 1.0) / 2.0 for image in images]
+ images = feature_extractor(images, return_tensors="pt").pixel_values
+ return images
+
+
+class IndexConfig(PretrainedConfig):
+ def __init__(
+ self,
+ clip_name_or_path="openai/clip-vit-large-patch14",
+ dataset_name="Isamu136/oxford_pets_with_l14_emb",
+ image_column="image",
+ index_name="embeddings",
+ index_path=None,
+ dataset_set="train",
+ metric_type=faiss.METRIC_L2,
+ faiss_device=-1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.clip_name_or_path = clip_name_or_path
+ self.dataset_name = dataset_name
+ self.image_column = image_column
+ self.index_name = index_name
+ self.index_path = index_path
+ self.dataset_set = dataset_set
+ self.metric_type = metric_type
+ self.faiss_device = faiss_device
+
+
+class Index:
+ """
+ Each index for a retrieval model is specific to the clip model used and the dataset used.
+ """
+
+ def __init__(self, config: IndexConfig, dataset: Dataset):
+ self.config = config
+ self.dataset = dataset
+ self.index_initialized = False
+ self.index_name = config.index_name
+ self.index_path = config.index_path
+ self.init_index()
+
+ def set_index_name(self, index_name: str):
+ self.index_name = index_name
+
+ def init_index(self):
+ if not self.index_initialized:
+ if self.index_path and self.index_name:
+ try:
+ self.dataset.add_faiss_index(
+ column=self.index_name, metric_type=self.config.metric_type, device=self.config.faiss_device
+ )
+ self.index_initialized = True
+ except Exception as e:
+ print(e)
+ logger.info("Index not initialized")
+ if self.index_name in self.dataset.features:
+ self.dataset.add_faiss_index(column=self.index_name)
+ self.index_initialized = True
+
+ def build_index(
+ self,
+ model=None,
+ feature_extractor: CLIPImageProcessor = None,
+ torch_dtype=torch.float32,
+ ):
+ if not self.index_initialized:
+ model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype)
+ feature_extractor = feature_extractor or CLIPImageProcessor.from_pretrained(self.config.clip_name_or_path)
+ self.dataset = get_dataset_with_emb_from_clip_model(
+ self.dataset,
+ model,
+ feature_extractor,
+ image_column=self.config.image_column,
+ index_name=self.config.index_name,
+ )
+ self.init_index()
+
+ def retrieve_imgs(self, vec, k: int = 20):
+ vec = np.array(vec).astype(np.float32)
+ return self.dataset.get_nearest_examples(self.index_name, vec, k=k)
+
+ def retrieve_imgs_batch(self, vec, k: int = 20):
+ vec = np.array(vec).astype(np.float32)
+ return self.dataset.get_nearest_examples_batch(self.index_name, vec, k=k)
+
+ def retrieve_indices(self, vec, k: int = 20):
+ vec = np.array(vec).astype(np.float32)
+ return self.dataset.search(self.index_name, vec, k=k)
+
+ def retrieve_indices_batch(self, vec, k: int = 20):
+ vec = np.array(vec).astype(np.float32)
+ return self.dataset.search_batch(self.index_name, vec, k=k)
+
+
+class Retriever:
+ def __init__(
+ self,
+ config: IndexConfig,
+ index: Index = None,
+ dataset: Dataset = None,
+ model=None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ self.config = config
+ self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ retriever_name_or_path: str,
+ index: Index = None,
+ dataset: Dataset = None,
+ model=None,
+ feature_extractor: CLIPImageProcessor = None,
+ **kwargs,
+ ):
+ config = kwargs.pop("config", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs)
+ return cls(config, index=index, dataset=dataset, model=model, feature_extractor=feature_extractor)
+
+ @staticmethod
+ def _build_index(
+ config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPImageProcessor = None
+ ):
+ dataset = dataset or load_dataset(config.dataset_name)
+ dataset = dataset[config.dataset_set]
+ index = Index(config, dataset)
+ index.build_index(model=model, feature_extractor=feature_extractor)
+ return index
+
+ def save_pretrained(self, save_directory):
+ os.makedirs(save_directory, exist_ok=True)
+ if self.config.index_path is None:
+ index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
+ self.index.dataset.get_index(self.config.index_name).save(index_path)
+ self.config.index_path = index_path
+ self.config.save_pretrained(save_directory)
+
+ def init_retrieval(self):
+ logger.info("initializing retrieval")
+ self.index.init_index()
+
+ def retrieve_imgs(self, embeddings: np.ndarray, k: int):
+ return self.index.retrieve_imgs(embeddings, k)
+
+ def retrieve_imgs_batch(self, embeddings: np.ndarray, k: int):
+ return self.index.retrieve_imgs_batch(embeddings, k)
+
+ def retrieve_indices(self, embeddings: np.ndarray, k: int):
+ return self.index.retrieve_indices(embeddings, k)
+
+ def retrieve_indices_batch(self, embeddings: np.ndarray, k: int):
+ return self.index.retrieve_indices_batch(embeddings, k)
+
+ def __call__(
+ self,
+ embeddings,
+ k: int = 20,
+ ):
+ return self.index.retrieve_imgs(embeddings, k)
+
+
+def map_txt_to_clip_feature(clip_model, tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > tokenizer.model_max_length:
+ removed_text = tokenizer.batch_decode(text_input_ids[:, tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : tokenizer.model_max_length]
+ text_embeddings = clip_model.get_text_features(text_input_ids.to(clip_model.device))
+ text_embeddings = text_embeddings / torch.linalg.norm(text_embeddings, dim=-1, keepdim=True)
+ text_embeddings = text_embeddings[:, None, :]
+ return text_embeddings[0][0].cpu().detach().numpy()
+
+
+def map_img_to_model_feature(model, feature_extractor, imgs, device):
+ for i, image in enumerate(imgs):
+ if not image.mode == "RGB":
+ imgs[i] = image.convert("RGB")
+ imgs = normalize_images(imgs)
+ retrieved_images = preprocess_images(imgs, feature_extractor).to(device)
+ image_embeddings = model(retrieved_images)
+ image_embeddings = image_embeddings / torch.linalg.norm(image_embeddings, dim=-1, keepdim=True)
+ image_embeddings = image_embeddings[None, ...]
+ return image_embeddings.cpu().detach().numpy()[0][0]
+
+
+def get_dataset_with_emb_from_model(dataset, model, feature_extractor, image_column="image", index_name="embeddings"):
+ return dataset.map(
+ lambda example: {
+ index_name: map_img_to_model_feature(model, feature_extractor, [example[image_column]], model.device)
+ }
+ )
+
+
+def get_dataset_with_emb_from_clip_model(
+ dataset, clip_model, feature_extractor, image_column="image", index_name="embeddings"
+):
+ return dataset.map(
+ lambda example: {
+ index_name: map_img_to_model_feature(
+ clip_model.get_image_features, feature_extractor, [example[image_column]], clip_model.device
+ )
+ }
+ )
diff --git a/diffusers/examples/research_projects/realfill/README.md b/diffusers/examples/research_projects/realfill/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..91821031d2e0d1fe3f217aed44d235f64531d43b
--- /dev/null
+++ b/diffusers/examples/research_projects/realfill/README.md
@@ -0,0 +1,118 @@
+# RealFill
+
+[RealFill](https://arxiv.org/abs/2309.16668) is a method to personalize text-to-image inpainting models like Stable Diffusion Inpainting, given just a few (1-5) images of a scene.
+The `train_realfill.py` script shows how to implement the training procedure for Stable Diffusion Inpainting.
+
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+cd to the realfill folder and run
+```bash
+cd realfill
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+
+### Toy example
+
+Now let's fill the real. For this example, we will use some images of the flower girl example from the paper.
+
+We already provide some images for testing in [this link](https://github.com/thuanz123/realfill/tree/main/data/flowerwoman)
+
+You only have to launch the training using:
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-2-inpainting"
+export TRAIN_DIR="data/flowerwoman"
+export OUTPUT_DIR="flowerwoman-model"
+
+accelerate launch train_realfill.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --resolution=512 \
+ --train_batch_size=16 \
+ --gradient_accumulation_steps=1 \
+ --unet_learning_rate=2e-4 \
+ --text_encoder_learning_rate=4e-5 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=100 \
+ --max_train_steps=2000 \
+ --lora_rank=8 \
+ --lora_dropout=0.1 \
+ --lora_alpha=16 \
+```
+
+### Training on a low-memory GPU:
+
+It is possible to run realfill on a low-memory GPU by using the following optimizations:
+- [gradient checkpointing and the 8-bit optimizer](#training-with-gradient-checkpointing-and-8-bit-optimizers)
+- [xformers](#training-with-xformers)
+- [setting grads to none](#set-grads-to-none)
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-2-inpainting"
+export TRAIN_DIR="data/flowerwoman"
+export OUTPUT_DIR="flowerwoman-model"
+
+accelerate launch train_realfill.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --resolution=512 \
+ --train_batch_size=16 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --use_8bit_adam \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+ --unet_learning_rate=2e-4 \
+ --text_encoder_learning_rate=4e-5 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=100 \
+ --max_train_steps=2000 \
+ --lora_rank=8 \
+ --lora_dropout=0.1 \
+ --lora_alpha=16 \
+```
+
+### Training with gradient checkpointing and 8-bit optimizers:
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train RealFill on a 16GB GPU.
+
+To install `bitsandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.
+
+### Set grads to none
+
+To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
+
+More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
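+
+Under the hood this corresponds to PyTorch's `set_to_none` option when clearing gradients, roughly as in this minimal sketch (the model and optimizer here are placeholders):
+
+```python
+import torch
+
+model = torch.nn.Linear(4, 4)
+optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+
+# Equivalent of --set_grads_to_none: drop the gradient tensors entirely
+# instead of filling them with zeros, which saves a little memory.
+optimizer.zero_grad(set_to_none=True)
+```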
+
+## Acknowledgements
+This repo is built upon the DreamBooth code from diffusers, and we thank the developers for their great work and for releasing the source code. A special thank you also goes to the RealFill authors for publishing such amazing work.
diff --git a/diffusers/examples/research_projects/realfill/infer.py b/diffusers/examples/research_projects/realfill/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3153307c4ad31892df7607d3ecf9b2af79b65187
--- /dev/null
+++ b/diffusers/examples/research_projects/realfill/infer.py
@@ -0,0 +1,91 @@
+import argparse
+import os
+
+import torch
+from PIL import Image, ImageFilter
+from transformers import CLIPTextModel
+
+from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
+
+
+parser = argparse.ArgumentParser(description="Inference")
+parser.add_argument(
+ "--model_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+)
+parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ required=True,
+ help="The directory of the validation image",
+)
+parser.add_argument(
+ "--validation_mask",
+ type=str,
+ default=None,
+ required=True,
+ help="The directory of the validation mask",
+)
+parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./test-infer/",
+ help="The output directory where predictions are saved",
+)
+parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible inference.")
+
+args = parser.parse_args()
+
+if __name__ == "__main__":
+ os.makedirs(args.output_dir, exist_ok=True)
+ generator = None
+
+ # create & load model
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float32, revision=None
+ )
+
+ pipe.unet = UNet2DConditionModel.from_pretrained(
+ args.model_path,
+ subfolder="unet",
+ revision=None,
+ )
+ pipe.text_encoder = CLIPTextModel.from_pretrained(
+ args.model_path,
+ subfolder="text_encoder",
+ revision=None,
+ )
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ pipe = pipe.to("cuda")
+
+ if args.seed is not None:
+ generator = torch.Generator(device="cuda").manual_seed(args.seed)
+
+ image = Image.open(args.validation_image)
+ mask_image = Image.open(args.validation_mask)
+
+ results = pipe(
+ ["a photo of sks"] * 16,
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=25,
+ guidance_scale=5,
+ generator=generator,
+ ).images
+
+ erode_kernel = ImageFilter.MaxFilter(3)
+ mask_image = mask_image.filter(erode_kernel)
+
+ blur_kernel = ImageFilter.BoxBlur(1)
+ mask_image = mask_image.filter(blur_kernel)
+
+ for idx, result in enumerate(results):
+ result = Image.composite(result, image, mask_image)
+ result.save(f"{args.output_dir}/{idx}.png")
+
+ del pipe
+ torch.cuda.empty_cache()
diff --git a/diffusers/examples/research_projects/realfill/requirements.txt b/diffusers/examples/research_projects/realfill/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8fbaf908a2c887e77b58d8f339f0e0643892f238
--- /dev/null
+++ b/diffusers/examples/research_projects/realfill/requirements.txt
@@ -0,0 +1,9 @@
+diffusers==0.20.1
+accelerate==0.23.0
+transformers==4.38.0
+peft==0.5.0
+torch==2.2.0
+torchvision>=0.16
+ftfy==6.1.1
+tensorboard==2.14.0
+Jinja2==3.1.4
diff --git a/diffusers/examples/research_projects/realfill/train_realfill.py b/diffusers/examples/research_projects/realfill/train_realfill.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7cc25df02b97c5cbd5705defa5c3d706a3d51c4
--- /dev/null
+++ b/diffusers/examples/research_projects/realfill/train_realfill.py
@@ -0,0 +1,984 @@
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.v2 as transforms_v2
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, PeftModel, get_peft_model
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, CLIPTextModel
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ StableDiffusionInpaintPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.20.1")
+
+logger = get_logger(__name__)
+
+
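+# Randomly punch up to `times` rectangular holes into an all-ones mask (inverting it half the time)
+# to simulate inpainting regions during training.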
+def make_mask(images, resolution, times=30):
+ mask, times = torch.ones_like(images[0:1, :, :]), np.random.randint(1, times)
+ min_size, max_size, margin = np.array([0.03, 0.25, 0.01]) * resolution
+ max_size = min(max_size, resolution - margin * 2)
+
+ for _ in range(times):
+ width = np.random.randint(int(min_size), int(max_size))
+ height = np.random.randint(int(min_size), int(max_size))
+
+ x_start = np.random.randint(int(margin), resolution - int(margin) - width + 1)
+ y_start = np.random.randint(int(margin), resolution - int(margin) - height + 1)
+ mask[:, y_start : y_start + height, x_start : x_start + width] = 0
+
+ mask = 1 - mask if random.random() < 0.5 else mask
+ return mask
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+    base_model: str = None,
+ repo_folder=None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+prompt: "a photo of sks"
+tags:
+- stable-diffusion-inpainting
+- stable-diffusion-inpainting-diffusers
+- text-to-image
+- diffusers
+- realfill
+- diffusers-training
+inference: true
+---
+ """
+ model_card = f"""
+# RealFill - {repo_id}
+
+This is a realfill model derived from {base_model}. The weights were trained using [RealFill](https://realfill.github.io/).
+You can find some example images in the following. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ epoch,
+):
+ logger.info(f"Running validation... \nGenerating {args.num_validation_images} images")
+
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+
+ # set `keep_fp32_wrapper` to True because we do not want to remove
+ # mixed precision hooks while we are still training
+ pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+ pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ target_dir = Path(args.train_data_dir) / "target"
+ target_image, target_mask = target_dir / "target.png", target_dir / "mask.png"
+ image, mask_image = Image.open(target_image), Image.open(target_mask)
+
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+
+ images = []
+ for _ in range(args.num_validation_images):
+ image = pipeline(
+ prompt="a photo of sks",
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=25,
+ guidance_scale=5,
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log({"validation": [wandb.Image(image, caption=str(i)) for i, image in enumerate(images)]})
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of images.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_conditioning`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run realfill validation every X steps. RealFill validation consists of running the conditioning"
+ " `args.validation_conditioning` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="realfill-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--unet_learning_rate",
+ type=float,
+ default=2e-4,
+ help="Learning rate to use for unet.",
+ )
+ parser.add_argument(
+ "--text_encoder_learning_rate",
+ type=float,
+ default=4e-5,
+ help="Learning rate to use for text encoder.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--wandb_key",
+ type=str,
+ default=None,
+ help=("If report to option is set to wandb, api-key for wandb used for login to wandb "),
+ )
+ parser.add_argument(
+ "--wandb_project_name",
+ type=str,
+ default=None,
+ help=("If report to option is set to wandb, project name in wandb for log tracking "),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=16,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=27,
+ help=("The alpha constant of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_dropout",
+ type=float,
+ default=0.0,
+ help="The dropout rate of the LoRA update matrices.",
+ )
+ parser.add_argument(
+ "--lora_bias",
+ type=str,
+ default="none",
+ help="The bias type of the Lora update matrices. Must be 'none', 'all' or 'lora_only'.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+class RealFillDataset(Dataset):
+ """
+ A dataset to prepare the training and conditioning images and
+ the masks with the dummy prompt for fine-tuning the model.
+ It pre-processes the images, masks and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ train_data_root,
+ tokenizer,
+ size=512,
+ ):
+ self.size = size
+ self.tokenizer = tokenizer
+
+ self.ref_data_root = Path(train_data_root) / "ref"
+ self.target_image = Path(train_data_root) / "target" / "target.png"
+ self.target_mask = Path(train_data_root) / "target" / "mask.png"
+ if not (self.ref_data_root.exists() and self.target_image.exists() and self.target_mask.exists()):
+ raise ValueError("Train images root doesn't exists.")
+
+ self.train_images_path = list(self.ref_data_root.iterdir()) + [self.target_image]
+ self.num_train_images = len(self.train_images_path)
+ self.train_prompt = "a photo of sks"
+
+ self.transform = transforms_v2.Compose(
+ [
+ transforms_v2.ToImage(),
+ transforms_v2.RandomResize(size, int(1.125 * size)),
+ transforms_v2.RandomCrop(size),
+ transforms_v2.ToDtype(torch.float32, scale=True),
+ transforms_v2.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self.num_train_images
+
+ def __getitem__(self, index):
+ example = {}
+
+ image = Image.open(self.train_images_path[index])
+ image = exif_transpose(image)
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ if index < len(self) - 1:
+ weighting = Image.new("L", image.size)
+ else:
+ weighting = Image.open(self.target_mask)
+ weighting = exif_transpose(weighting)
+
+ image, weighting = self.transform(image, weighting)
+ example["images"], example["weightings"] = image, weighting < 0
+
+ if random.random() < 0.1:
+ example["masks"] = torch.ones_like(example["images"][0:1, :, :])
+ else:
+ example["masks"] = make_mask(example["images"], self.size)
+
+ example["conditioning_images"] = example["images"] * (example["masks"] < 0.5)
+
+ train_prompt = "" if random.random() < 0.1 else self.train_prompt
+ example["prompt_ids"] = self.tokenizer(
+ train_prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ return example
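As a usage sketch for the dataset above (the model id and the `./data` layout with `ref/`, `target/target.png`, and `target/mask.png` are placeholders matching the checks in `__init__`):

```python
# Minimal sketch: instantiate the dataset and inspect one sample (model id and paths are placeholders).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", subfolder="tokenizer", use_fast=False
)
dataset = RealFillDataset(train_data_root="./data", tokenizer=tokenizer, size=512)

sample = dataset[0]
print(sample["images"].shape)               # torch.Size([3, 512, 512])
print(sample["masks"].shape)                # torch.Size([1, 512, 512])
print(sample["conditioning_images"].shape)  # torch.Size([3, 512, 512])
print(sample["prompt_ids"].shape)           # torch.Size([1, 77])
```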
+
+
+def collate_fn(examples):
+ input_ids = [example["prompt_ids"] for example in examples]
+ images = [example["images"] for example in examples]
+
+ masks = [example["masks"] for example in examples]
+ weightings = [example["weightings"] for example in examples]
+ conditioning_images = [example["conditioning_images"] for example in examples]
+
+ images = torch.stack(images)
+ images = images.to(memory_format=torch.contiguous_format).float()
+
+ masks = torch.stack(masks)
+ masks = masks.to(memory_format=torch.contiguous_format).float()
+
+ weightings = torch.stack(weightings)
+ weightings = weightings.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_images = torch.stack(conditioning_images)
+ conditioning_images = conditioning_images.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "images": images,
+ "masks": masks,
+ "weightings": weightings,
+ "conditioning_images": conditioning_images,
+ }
+ return batch
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_dir=logging_dir,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ wandb.login(key=args.wandb_key)
+ wandb.init(project=args.wandb_project_name)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ config = LoraConfig(
+ r=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ target_modules=["to_k", "to_q", "to_v", "key", "query", "value"],
+ lora_dropout=args.lora_dropout,
+ bias=args.lora_bias,
+ )
+ unet = get_peft_model(unet, config)
+
+ config = LoraConfig(
+ r=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ target_modules=["k_proj", "q_proj", "v_proj"],
+ lora_dropout=args.lora_dropout,
+ bias=args.lora_bias,
+ )
+ text_encoder = get_peft_model(text_encoder, config)
+
+ vae.requires_grad_(False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ text_encoder.gradient_checkpointing_enable()
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ sub_dir = (
+ "unet"
+ if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
+ else "text_encoder"
+ )
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ sub_dir = (
+ "unet"
+ if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
+ else "text_encoder"
+ )
+ model_cls = (
+ UNet2DConditionModel
+ if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
+ else CLIPTextModel
+ )
+
+ load_model = model_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder=sub_dir)
+ load_model = PeftModel.from_pretrained(load_model, input_dir, subfolder=sub_dir)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.unet_learning_rate = (
+ args.unet_learning_rate
+ * args.gradient_accumulation_steps
+ * args.train_batch_size
+ * accelerator.num_processes
+ )
+
+ args.text_encoder_learning_rate = (
+ args.text_encoder_learning_rate
+ * args.gradient_accumulation_steps
+ * args.train_batch_size
+ * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ optimizer = optimizer_class(
+ [
+ {"params": unet.parameters(), "lr": args.unet_learning_rate},
+ {"params": text_encoder.parameters(), "lr": args.text_encoder_learning_rate},
+ ],
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = RealFillDataset(
+ train_data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=1,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ accelerator.init_trackers("realfill", config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ text_encoder.train()
+
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet, text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * 0.18215
+
+ # Convert masked images to latent space
+ conditionings = vae.encode(batch["conditioning_images"].to(dtype=weight_dtype)).latent_dist.sample()
+ conditionings = conditionings * 0.18215
+
+ # Downsample mask and weighting so that they match with the latents
+ masks, size = batch["masks"].to(dtype=weight_dtype), latents.shape[2:]
+ masks = F.interpolate(masks, size=size)
+
+ weightings = batch["weightings"].to(dtype=weight_dtype)
+ weightings = F.interpolate(weightings, size=size)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Concatenate noisy latents, masks and conditionings to get inputs to unet
+ inputs = torch.cat([noisy_latents, masks, conditionings], dim=1)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(inputs, timesteps, encoder_hidden_states).sample
+
+ # Compute the diffusion loss
+ assert noise_scheduler.config.prediction_type == "epsilon"
+ loss = (weightings * F.mse_loss(model_pred.float(), noise.float(), reduction="none")).mean()
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ if args.report_to == "wandb":
+ accelerator.print(progress_bar)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+
+ logs = {"loss": loss.detach().item()}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).merge_and_unload(),
+ text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).merge_and_unload(),
+ revision=args.revision,
+ )
+
+ pipeline.save_pretrained(args.output_dir)
+
+ # Final inference
+ images = log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
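After training, the script saves a merged `StableDiffusionInpaintPipeline` to `--output_dir`. A hedged inference sketch, mirroring `log_validation` (the output directory and data paths are placeholders; the prompt, step count, and guidance scale match the values hard-coded above):

```python
# Minimal sketch: run inpainting with the pipeline saved by train_realfill.py (paths are placeholders).
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "./realfill-model", torch_dtype=torch.float16
).to("cuda")

image = Image.open("./data/target/target.png").convert("RGB")
mask = Image.open("./data/target/mask.png")

result = pipe(
    prompt="a photo of sks",   # the fixed training prompt used throughout the script
    image=image,
    mask_image=mask,
    num_inference_steps=25,
    guidance_scale=5,
).images[0]
result.save("filled.png")
```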
diff --git a/diffusers/examples/research_projects/scheduled_huber_loss_training/README.md b/diffusers/examples/research_projects/scheduled_huber_loss_training/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..239f94ba10057b20b193606a0678bc4853934ee7
--- /dev/null
+++ b/diffusers/examples/research_projects/scheduled_huber_loss_training/README.md
@@ -0,0 +1,15 @@
+# Scheduled Pseudo-Huber Loss for Diffusers
+
+These are modified versions of the Diffusers training scripts that add the option of training text2image models with Scheduled Pseudo-Huber loss, introduced in https://arxiv.org/abs/2403.16728 (https://github.com/kabachuha/SPHL-for-stable-diffusion).
+
+## Why might this be useful?
+
+- If you suspect that part of the training dataset might be corrupted and you don't want these outliers to distort the model's output
+
+- If you want to improve the aesthetic quality of generated pictures by helping the model disentangle concepts and be less influenced by other sorts of pictures.
+
+See https://github.com/huggingface/diffusers/issues/7488 for a detailed description.
+
+## Instructions
+
+Usage is the same as for the corresponding vanilla Diffusers scripts: https://github.com/huggingface/diffusers/tree/main/examples
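To make the loss concrete, here is a hedged sketch of the Pseudo-Huber term these scripts substitute for plain MSE; it mirrors the `huber` branch of the `conditional_loss` helper added in the training scripts below, while the toy tensors and printed values are illustrative only. The "scheduled" part then varies `huber_c` with the sampled timestep according to `--huber_schedule` (`constant`, `exponential`, or `snr`).

```python
# Hedged sketch of the Pseudo-Huber loss (mirrors the `huber` branch of conditional_loss below);
# the toy tensors are made up to show the reduced sensitivity to outliers.
import torch
import torch.nn.functional as F

def pseudo_huber(pred: torch.Tensor, target: torch.Tensor, huber_c: float = 0.1) -> torch.Tensor:
    # Behaves like L2 for small residuals and like L1 for large ones.
    return 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)

pred = torch.tensor([0.0, 0.0, 0.0])
target = torch.tensor([0.1, 0.1, 5.0])                        # last entry plays the role of a corrupted outlier
print(F.mse_loss(pred, target).item())                        # ~8.34, dominated by the outlier
print(pseudo_huber(pred, target, huber_c=0.1).mean().item())  # ~0.33, grows only linearly in the outlier
```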
diff --git a/diffusers/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/diffusers/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f7ca2262dccfaa19cfc3750c2a2f317f4f6c4f5
--- /dev/null
+++ b/diffusers/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
@@ -0,0 +1,1518 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import importlib
+import itertools
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, model_info, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ train_text_encoder=False,
+ prompt: str = None,
+ repo_folder: str = None,
+ pipeline: DiffusionPipeline = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# DreamBooth - {repo_id}
+
+This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
+You can find some example images in the following. \n
+{img_str}
+
+DreamBooth for the text encoder was enabled: {train_text_encoder}.
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ prompt=prompt,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = ["text-to-image", "dreambooth", "diffusers-training"]
+ if isinstance(pipeline, StableDiffusionPipeline):
+ tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
+ else:
+ tags.extend(["if", "if-diffusers"])
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ prompt_embeds,
+ negative_prompt_embeds,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ pipeline_args = {}
+
+ if vae is not None:
+ pipeline_args["vae"] = vae
+
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ module = importlib.import_module("diffusers")
+ scheduler_class = getattr(module, args.validation_scheduler)
+ pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.pre_compute_text_embeddings:
+ pipeline_args = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ }
+ else:
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ if args.validation_images is None:
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+ else:
+ for image in args.validation_images:
+ image = Image.open(image)
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more details"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+
+ parser.add_argument(
+ "--offset_noise",
+ action="store_true",
+ default=False,
+ help=(
+ "Fine-tuning against a modified noise"
+ " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--pre_compute_text_embeddings",
+ action="store_true",
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+ )
+ parser.add_argument(
+ "--tokenizer_max_length",
+ type=int,
+ default=None,
+ required=False,
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+ )
+ parser.add_argument(
+ "--text_encoder_use_attention_mask",
+ action="store_true",
+ required=False,
+ help="Whether to use attention mask for the text encoder",
+ )
+ parser.add_argument(
+ "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
+ )
+ parser.add_argument(
+ "--validation_images",
+ required=False,
+ default=None,
+ nargs="+",
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+ )
+ parser.add_argument(
+ "--class_labels_conditioning",
+ required=False,
+ default=None,
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+ parser.add_argument(
+ "--validation_scheduler",
+ type=str,
+ default="DPMSolverMultistepScheduler",
+ choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
+ help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ encoder_hidden_states=None,
+ class_prompt_encoder_hidden_states=None,
+ tokenizer_max_length=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.encoder_hidden_states = encoder_hidden_states
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+ self.tokenizer_max_length = tokenizer_max_length
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.encoder_hidden_states is not None:
+ example["instance_prompt_ids"] = self.encoder_hidden_states
+ else:
+ text_inputs = tokenize_prompt(
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["instance_prompt_ids"] = text_inputs.input_ids
+ example["instance_attention_mask"] = text_inputs.attention_mask
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+
+ if self.class_prompt_encoder_hidden_states is not None:
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
+ else:
+ class_text_inputs = tokenize_prompt(
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["class_prompt_ids"] = class_text_inputs.input_ids
+ example["class_attention_mask"] = class_text_inputs.attention_mask
+
+ return example
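A hedged sketch of how this dataset pairs with the `collate_fn` below when prior preservation is enabled (the model id and data directories are placeholders):

```python
# Minimal sketch: batch instance and class examples together (model id and paths are placeholders).
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer", use_fast=False
)
dataset = DreamBoothDataset(
    instance_data_root="./instance_images",
    instance_prompt="a photo of sks dog",
    tokenizer=tokenizer,
    class_data_root="./class_images",
    class_prompt="a photo of a dog",
    size=512,
)
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, with_prior_preservation=True),
)

batch = next(iter(loader))
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 512, 512]): instance and class images stacked
print(batch["input_ids"].shape)     # torch.Size([4, 77])
```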
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ has_attention_mask = "instance_attention_mask" in examples[0]
+
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask = [example["instance_attention_mask"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask += [example["class_attention_mask"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+
+ if has_attention_mask:
+ attention_mask = torch.cat(attention_mask, dim=0)
+ batch["attention_mask"] = attention_mask
+
+ return batch
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def model_has_vae(args):
+ config_file_name = os.path.join("vae", AutoencoderKL.config_name)
+ if os.path.isdir(args.pretrained_model_name_or_path):
+ config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
+ return os.path.isfile(config_file_name)
+ else:
+ files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
+ return any(file.rfilename == config_file_name for file in files_in_repo)
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+ if tokenizer_max_length is not None:
+ max_length = tokenizer_max_length
+ else:
+ max_length = tokenizer.model_max_length
+
+ text_inputs = tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+
+ return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+ text_input_ids = input_ids.to(text_encoder.device)
+
+ if text_encoder_use_attention_mask:
+ attention_mask = attention_mask.to(text_encoder.device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ return_dict=False,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
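+# "l2" is plain MSE. "huber" uses the pseudo-Huber form 2*c*(sqrt((pred - target)**2 + c**2) - c),
+# which behaves like a squared error for residuals much smaller than c and like 2*c*|residual| for
+# large residuals. "smooth_l1" is the same curve divided by c, approaching 2*|residual| for large
+# residuals and residual**2 / c near zero.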
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+
+ if model_has_vae(args):
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ else:
+ vae = None
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(text_encoder))):
+ # load transformers style into model
+ load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
+ model.config = load_model.config
+ else:
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if vae is not None:
+ vae.requires_grad_(False)
+
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training. copy of the weights should still be float32."
+ )
+
+ if unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")
+
+ if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
+ raise ValueError(
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.pre_compute_text_embeddings:
+
+ def compute_text_embeddings(prompt):
+ with torch.no_grad():
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+ prompt_embeds = encode_prompt(
+ text_encoder,
+ text_inputs.input_ids,
+ text_inputs.attention_mask,
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ return prompt_embeds
+
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+ if args.validation_prompt is not None:
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+ else:
+ validation_prompt_encoder_hidden_states = None
+
+ if args.class_prompt is not None:
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
+ else:
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ text_encoder = None
+ tokenizer = None
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ pre_computed_encoder_hidden_states = None
+ validation_prompt_encoder_hidden_states = None
+ validation_prompt_negative_prompt_embeds = None
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and text_encoder to device and cast to weight_dtype
+ if vae is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ if not args.train_text_encoder and text_encoder is not None:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ tracker_config.pop("validation_images")
+ accelerator.init_trackers("dreambooth", config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+
+ if vae is not None:
+ # Convert images to latent space
+ model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ else:
+ model_input = pixel_values
+
+ # Sample noise that we'll add to the model input
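+ # With --offset_noise, an extra per-sample, per-channel constant (shape [bsz, C, 1, 1], broadcast
+ # over height and width, scaled by 0.1) is added on top of the standard Gaussian noise, shifting
+ # part of the noise energy to very low frequencies.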
+ if args.offset_noise:
+ noise = torch.randn_like(model_input) + 0.1 * torch.randn(
+ model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
+ )
+ else:
+ noise = torch.randn_like(model_input)
+ bsz, channels, height, width = model_input.shape
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
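+ # For the scheduled Huber losses a single timestep is drawn for the whole batch (and repeated
+ # below) so that huber_c can be computed as a Python scalar. "exponential" decays huber_c from
+ # 1 at t=0 toward args.huber_c at the last timestep (for 0 < huber_c < 1); "snr" interpolates
+ # between 1 for clean samples and args.huber_c as the noise level sigma grows; "constant" keeps
+ # args.huber_c fixed.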
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.pre_compute_text_embeddings:
+ encoder_hidden_states = batch["input_ids"]
+ else:
+ encoder_hidden_states = encode_prompt(
+ text_encoder,
+ batch["input_ids"],
+ batch["attention_mask"],
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
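+ # Some UNets expect twice as many input channels as the latent (e.g. models conditioned on an
+ # extra image); in that case the noisy input is simply duplicated along the channel dimension.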
+ if unwrap_model(unet).config.in_channels == channels * 2:
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
+
+ if args.class_labels_conditioning == "timesteps":
+ class_labels = timesteps
+ else:
+ class_labels = None
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
+ )[0]
+
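+ # A 6-channel prediction indicates a model that also predicts a learned variance; keep only the
+ # first half of the channels, since we train on the simplified objective.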
+ if model_pred.shape[1] == 6:
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+ # Compute prior loss
+ prior_loss = conditional_loss(
+ model_pred_prior.float(),
+ target_prior.float(),
+ reduction="mean",
+ loss_type=args.loss_type,
+ huber_c=huber_c,
+ )
+
+ # Compute instance loss
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
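+ # base_weight = min(SNR, snr_gamma) / SNR, i.e. the Min-SNR-gamma weight for epsilon prediction;
+ # adding 1 below gives (min(SNR, snr_gamma) + SNR) / SNR for v-prediction.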
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ images = []
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
+ tokenizer,
+ unwrap_model(unet),
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ validation_prompt_encoder_hidden_states,
+ validation_prompt_negative_prompt_embeds,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ pipeline_args = {}
+
+ if text_encoder is not None:
+ pipeline_args["text_encoder"] = unwrap_model(text_encoder)
+
+ if args.skip_save_text_encoder:
+ pipeline_args["text_encoder"] = None
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ pipeline=pipeline,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/diffusers/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..663dbbf99473fab9e4e02849f7ca55ad88d30152
--- /dev/null
+++ b/diffusers/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
@@ -0,0 +1,1504 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.utils import (
+ check_min_version,
+ convert_state_dict_to_diffusers,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model=str,
+ train_text_encoder=False,
+ prompt=str,
+ repo_folder=None,
+ pipeline: DiffusionPipeline = None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# LoRA DreamBooth - {repo_id}
+
+These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
+{img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ prompt=prompt,
+ model_description=model_description,
+ inference=True,
+ )
+ tags = ["text-to-image", "diffusers", "lora", "diffusers-training"]
+ if isinstance(pipeline, StableDiffusionPipeline):
+ tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
+ else:
+ tags.extend(["if", "if-diffusers"])
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+ if args.validation_images is None:
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, generator=generator).images[0]
+ images.append(image)
+ else:
+ images = []
+ for image in args.validation_images:
+ image = Image.open(image)
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--pre_compute_text_embeddings",
+ action="store_true",
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+ )
+ parser.add_argument(
+ "--tokenizer_max_length",
+ type=int,
+ default=None,
+ required=False,
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+ )
+ parser.add_argument(
+ "--text_encoder_use_attention_mask",
+ action="store_true",
+ required=False,
+ help="Whether to use attention mask for the text encoder",
+ )
+ parser.add_argument(
+ "--validation_images",
+ required=False,
+ default=None,
+ nargs="+",
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+ )
+ parser.add_argument(
+ "--class_labels_conditioning",
+ required=False,
+ default=None,
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+ return args
+
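+# Illustrative invocation only -- the model id and directory paths below are placeholders, not defaults:
+#
+#   accelerate launch train_dreambooth_lora.py \
+#     --pretrained_model_name_or_path="<hub-model-id-or-local-path>" \
+#     --instance_data_dir="<path-to-instance-images>" \
+#     --instance_prompt="a photo of sks dog" \
+#     --output_dir="lora-dreambooth-model" \
+#     --resolution=512 --train_batch_size=1 --learning_rate=1e-4 \
+#     --max_train_steps=500 --loss_type="huber" --huber_schedule="snr" --huber_c=0.1 --rank=4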
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ encoder_hidden_states=None,
+ class_prompt_encoder_hidden_states=None,
+ tokenizer_max_length=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.encoder_hidden_states = encoder_hidden_states
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+ self.tokenizer_max_length = tokenizer_max_length
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.encoder_hidden_states is not None:
+ example["instance_prompt_ids"] = self.encoder_hidden_states
+ else:
+ text_inputs = tokenize_prompt(
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["instance_prompt_ids"] = text_inputs.input_ids
+ example["instance_attention_mask"] = text_inputs.attention_mask
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+
+ if self.class_prompt_encoder_hidden_states is not None:
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
+ else:
+ class_text_inputs = tokenize_prompt(
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["class_prompt_ids"] = class_text_inputs.input_ids
+ example["class_attention_mask"] = class_text_inputs.attention_mask
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ has_attention_mask = "instance_attention_mask" in examples[0]
+
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask = [example["instance_attention_mask"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ if has_attention_mask:
+ attention_mask += [example["class_attention_mask"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+
+ if has_attention_mask:
+ attention_mask = torch.cat(attention_mask, dim=0)
+ batch["attention_mask"] = attention_mask
+
+ return batch
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+ if tokenizer_max_length is not None:
+ max_length = tokenizer_max_length
+ else:
+ max_length = tokenizer.model_max_length
+
+ text_inputs = tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+
+ return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+ text_input_ids = input_ids.to(text_encoder.device)
+
+ if text_encoder_use_attention_mask:
+ attention_mask = attention_mask.to(text_encoder.device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ return_dict=False,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ try:
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ except OSError:
+ # DeepFloyd IF does not have a VAE, so just set it to None.
+ # We don't have to error out here.
+ vae = None
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ if vae is not None:
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference and keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ if vae is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
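+ # lora_alpha is set equal to the rank, so the effective LoRA scaling factor (lora_alpha / r) is 1.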
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
+ )
+ unet.add_adapter(unet_lora_config)
+
+ # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder.add_adapter(text_lora_config)
+
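+ # Helper that strips both the Accelerate wrapper and, if the model was compiled with
+ # torch.compile, the compilation wrapper, returning the underlying module.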
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # There are only two options here: either just the unet attention processor layers,
+ # or both the unet and text encoder attention layers.
+ unet_lora_layers_to_save = None
+ text_encoder_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(model, type(unwrap_model(text_encoder))):
+ text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop the weight so that the corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder))):
+ text_encoder_ = model
+ else:
+ raise ValueError(f"unexpected model to load: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+
+ unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.append(text_encoder_)
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.append(text_encoder)
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ if args.train_text_encoder:
+ params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.pre_compute_text_embeddings:
+
+ def compute_text_embeddings(prompt):
+ with torch.no_grad():
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+ prompt_embeds = encode_prompt(
+ text_encoder,
+ text_inputs.input_ids,
+ text_inputs.attention_mask,
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ return prompt_embeds
+
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+ if args.validation_prompt is not None:
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+ else:
+ validation_prompt_encoder_hidden_states = None
+
+ if args.class_prompt is not None:
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
+ else:
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ text_encoder = None
+ tokenizer = None
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ pre_computed_encoder_hidden_states = None
+ validation_prompt_encoder_hidden_states = None
+ validation_prompt_negative_prompt_embeds = None
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ tracker_config.pop("validation_images")
+ accelerator.init_trackers("dreambooth-lora", config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+
+ if vae is not None:
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ else:
+ model_input = pixel_values
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz, channels, height, width = model_input.shape
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
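+ # For the scheduled Huber/smooth-L1 losses, a single timestep is sampled for the whole
+ # batch so that the timestep-dependent huber_c below stays a scalar; the timestep is
+ # broadcast to all `bsz` samples further down.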
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
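+ # huber_c sets the transition point of the pseudo-Huber loss and is scheduled over
+ # the diffusion timestep t (T = num_train_timesteps):
+ #   exponential: huber_c = exp(-alpha * t) with alpha = -ln(huber_c_0) / T, decaying
+ #                from 1 at t = 0 to the configured --huber_c at t = T
+ #   snr:         huber_c = (1 - huber_c_0) / (1 + sigma_t)^2 + huber_c_0, where
+ #                sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) is the noise level
+ #   constant:    huber_c = huber_c_0 for every timestep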
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.pre_compute_text_embeddings:
+ encoder_hidden_states = batch["input_ids"]
+ else:
+ encoder_hidden_states = encode_prompt(
+ text_encoder,
+ batch["input_ids"],
+ batch["attention_mask"],
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
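+ # Some UNet variants (e.g. super-resolution models that take a concatenated conditioning
+ # image) expect twice as many input channels; duplicate the noisy input along the channel
+ # dimension to match.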
+ if unwrap_model(unet).config.in_channels == channels * 2:
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
+
+ if args.class_labels_conditioning == "timesteps":
+ class_labels = timesteps
+ else:
+ class_labels = None
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ encoder_hidden_states,
+ class_labels=class_labels,
+ return_dict=False,
+ )[0]
+
+ # If the model predicts variance, throw that prediction away. We will only train on the
+ # simplified training objective. This means that all schedulers using the fine-tuned
+ # model must be configured to use one of the fixed variance types.
+ if model_pred.shape[1] == 6:
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+
+ # Compute prior loss
+ prior_loss = conditional_loss(
+ model_pred_prior.float(),
+ target_prior.float(),
+ reduction="mean",
+ loss_type=args.loss_type,
+ huber_c=huber_c,
+ )
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ if args.pre_compute_text_embeddings:
+ pipeline_args = {
+ "prompt_embeds": validation_prompt_encoder_hidden_states,
+ "negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
+ }
+ else:
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ )
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet = unet.to(torch.float32)
+
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder = unwrap_model(text_encoder)
+ text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
+ else:
+ text_encoder_state_dict = None
+
+ StableDiffusionLoraLoaderMixin.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
+ )
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ pipeline=pipeline,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/diffusers/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..d16780131139123a61cff6398a6bd45d26678993
--- /dev/null
+++ b/diffusers/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
@@ -0,0 +1,2078 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import load_file, save_file
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ EDMEulerScheduler,
+ EulerDiscreteScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+ check_min_version,
+ convert_all_state_dict_to_peft,
+ convert_state_dict_to_diffusers,
+ convert_state_dict_to_kohya,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def determine_scheduler_type(pretrained_model_name_or_path, revision):
+ model_index_filename = "model_index.json"
+ if os.path.isdir(pretrained_model_name_or_path):
+ model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
+ else:
+ model_index = hf_hub_download(
+ repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
+ )
+
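+ # model_index.json stores each pipeline component as a [library, class_name] pair,
+ # so index 1 below is the scheduler's class name (later checked for "EDM").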
+ with open(model_index, "r") as f:
+ scheduler_type = json.load(f)["scheduler"][1]
+ return scheduler_type
+
+
+def save_model_card(
+ repo_id: str,
+ use_dora: bool,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+ vae_path=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} LoRA adaption weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
+
+## Trigger words
+
+You should use {instance_prompt} to trigger the image generation.
+
+## Download model
+
+Weights for this model are available in Safetensors format.
+
+[Download]({repo_id}/tree/main) them in the Files & versions tab.
+
+"""
+ if "playground" in base_model:
+ model_description += """\n
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora" if not use_dora else "dora",
+ "template:sd-lora",
+ ]
+ if "playground" in base_model:
+ tags.extend(["playground", "playground-diffusers"])
+ else:
+ tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if not args.do_edm_style_training:
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
+ # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+ inference_ctx = (
+ contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
+ )
+
+ with inference_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--do_edm_style_training",
+ default=False,
+ action="store_true",
+ help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--output_kohya_format",
+ action="store_true",
+ help="Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--use_dora",
+ action="store_true",
+ default=False,
+ help=(
+ "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
+ "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
+ ),
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
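+# Example launch command (a sketch only; the model id, data directory, and prompt below are
+# placeholders and should be replaced with your own values):
+#
+#   accelerate launch train_dreambooth_lora_sdxl.py \
+#     --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
+#     --instance_data_dir="./instance_images" \
+#     --instance_prompt="a photo of sks dog" \
+#     --output_dir="lora-dreambooth-model" \
+#     --mixed_precision="fp16" \
+#     --resolution=1024 --train_batch_size=1 --gradient_accumulation_steps=4 \
+#     --learning_rate=1e-4 --lr_scheduler="constant" --lr_warmup_steps=0 \
+#     --max_train_steps=500 --loss_type="huber" --huber_schedule="snr"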
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ # image processing to prepare for using SD-XL micro-conditioning
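+ # Each image's original (height, width) and the crop's top-left corner are recorded so
+ # they can later be fed to SDXL's size/crop micro-conditioning.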
+ self.original_sizes = []
+ self.crop_top_lefts = []
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ self.original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ self.crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ original_size = self.original_sizes[index % self.num_instance_images]
+ crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+ example["original_size"] = original_size
+ example["crop_top_left"] = crop_top_left
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # no custom prompts were provided; fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+ original_sizes = [example["original_size"] for example in examples]
+ crop_top_lefts = [example["crop_top_left"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+ original_sizes += [example["original_size"] for example in examples]
+ crop_top_lefts += [example["crop_top_left"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {
+ "pixel_values": pixel_values,
+ "prompts": prompts,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+ return batch
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
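+# SDXL conditions the UNet on the penultimate hidden states of both text encoders,
+# concatenated along the feature dimension, plus the pooled output of the second encoder.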
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
+ )
+
+ # Only the pooled output of the final text encoder is used (earlier values are overwritten).
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
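+# The "huber" branch implements the pseudo-Huber loss
+#   L = 2 * c * (sqrt((pred - target)^2 + c^2) - c),
+# which behaves like a squared (L2) loss for residuals much smaller than c and like an
+# L1 loss for residuals much larger than c. "smooth_l1" is the same expression with the
+# 2 * c prefactor replaced by 2 (i.e. divided by c), and "l2" is plain MSE.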
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+ weighting: Optional[torch.Tensor] = None,
+):
+ if loss_type == "l2":
+ if weighting is not None:
+ loss = torch.mean(
+ (weighting * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction=reduction)
+
+ elif loss_type == "huber":
+ if weighting is not None:
+ loss = torch.mean(
+ (
+ 2
+ * huber_c
+ * (
+ torch.sqrt(weighting.float() * (model_pred.float() - target.float()) ** 2 + huber_c**2)
+ - huber_c
+ )
+ ).reshape(target.shape[0], -1),
+ 1,
+ )
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ if weighting is not None:
+ loss = torch.mean(
+ (
+ 2
+ * (
+ torch.sqrt(weighting.float() * (model_pred.float() - target.float()) ** 2 + huber_c**2)
+ - huber_c
+ )
+ ).reshape(target.shape[0], -1),
+ 1,
+ )
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.do_edm_style_training and args.snr_gamma is not None:
+ raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
+ if "EDM" in scheduler_type:
+ args.do_edm_style_training = True
+ noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ logger.info("Performing EDM-style training!")
+ elif args.do_edm_style_training:
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ logger.info("Performing EDM-style training!")
+ else:
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ latents_mean = latents_std = None
+ if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+ latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
+ "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ use_dora=args.use_dora,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ unet.add_adapter(unet_lora_config)
+
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ use_dora=args.use_dora,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+ text_encoder_two.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # There are only two options here: either just the unet attention processor (LoRA) layers,
+ # or the unet plus the text encoder attention layers.
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ _set_state_dict_into_text_encoder(
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_, text_encoder_two_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one, text_encoder_two])
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+ if args.train_text_encoder:
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
+ # Optimization parameters
+ unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # different learning rate for text encoder and unet
+ text_lora_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ text_lora_parameters_two_with_lr = {
+ "params": text_lora_parameters_two,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ unet_lora_parameters_with_lr,
+ text_lora_parameters_one_with_lr,
+ text_lora_parameters_two_with_lr,
+ ]
+ else:
+ params_to_optimize = [unet_lora_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+ # --learning_rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+ params_to_optimize[2]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Computes additional embeddings/ids required by the SDXL UNet.
+ # regular text embeddings (when `train_text_encoder` is not True)
+ # pooled text embeddings
+ # time ids
+
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
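+ # Annotation: for SDXL the `add_time_ids` micro-conditioning is simply
+ # original_size + crops_coords_top_left + target_size flattened into 6 values,
+ # e.g. (1024, 1024) + (0, 0) + (1024, 1024) -> tensor([[1024., 1024., 0., 0., 1024., 1024.]]).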
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if not args.train_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ if not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ unet_add_text_embeds = instance_pooled_prompt_embeds
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
+ # If we're optimizing the text encoder (whether the instance prompt is shared across all images or custom prompts are used),
+ # we need to tokenize and encode the batch prompts on every training step.
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = (
+ "dreambooth-lora-sd-xl"
+ if "playground" not in args.pretrained_model_name_or_path
+ else "dreambooth-lora-playground"
+ )
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
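+ # Annotation: `get_sigmas` looks up the scheduler sigma for each sampled timestep
+ # and unsqueezes it to `n_dim` dimensions (e.g. [bsz] -> [bsz, 1, 1, 1]) so it
+ # broadcasts against the 4D latent tensors in the EDM branches below.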
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+
+ # Set requires_grad = True on the top-level embeddings so that gradient checkpointing works
+ accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts)
+ tokens_two = tokenize_prompt(tokenizer_two, prompts)
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+
+ if latents_mean is None and latents_std is None:
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+ else:
+ latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
+ latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
+ model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ if not args.do_edm_style_training:
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
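+ # Annotation: the branches above schedule the pseudo-Huber parameter as a
+ # function of the sampled timestep t (with T = num_train_timesteps):
+ #   exponential: huber_c = exp(t * ln(args.huber_c) / T), decaying from 1 at t=0 to args.huber_c at t=T
+ #   snr:         huber_c = (1 - args.huber_c) / (1 + sigma_t)**2 + args.huber_c, with sigma_t = sqrt((1 - abar_t) / abar_t)
+ #   constant:    huber_c = args.huber_c for every step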
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+ timesteps = timesteps.long()
+ else:
+ if "huber" in args.loss_type or "l1" in args.loss_type:
+ raise NotImplementedError("Huber loss is not implemented for EDM training yet!")
+ # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
+ # instead of discrete timesteps, so here we sample indices to get the noise levels
+ # from `scheduler.timesteps`
+ indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
+ timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+ # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
+ # We then precondition the final model inputs based on these sigmas instead of the timesteps.
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ if args.do_edm_style_training:
+ sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
+ if "EDM" in scheduler_type:
+ inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
+ else:
+ inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
+
+ # time ids
+ add_time_ids = torch.cat(
+ [
+ compute_time_ids(original_size=s, crops_coords_top_left=c)
+ for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+ ]
+ )
+
+ # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
+ if not train_dataset.custom_instance_prompts:
+ elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
+ else:
+ elems_to_repeat_text_embeds = 1
+
+ # Predict the noise residual
+ if not args.train_text_encoder:
+ unet_added_conditions = {
+ "time_ids": add_time_ids,
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
+ }
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+ else:
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[tokens_one, tokens_two],
+ )
+ unet_added_conditions.update(
+ {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
+ )
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ weighting = None
+ if args.do_edm_style_training:
+ # Similar to the input preconditioning, the model predictions are also preconditioned
+ # on noised model inputs (before preconditioning) and the sigmas.
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ if "EDM" in scheduler_type:
+ model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
+ else:
+ if noise_scheduler.config.prediction_type == "epsilon":
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
+ noisy_model_input / (sigmas**2 + 1)
+ )
+ # We are not doing weighting here because it tends to result in numerical problems.
+ # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+ # There might be other alternatives for weighting as well:
+ # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
+ if "EDM" not in scheduler_type:
+ weighting = (sigmas**-2.0).float()
+
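+ # Annotation: for the non-EDM schedulers, the two branches above convert the raw
+ # epsilon / v-prediction into a denoised-sample estimate (x0_hat), so the loss is
+ # computed in sample space against `model_input`; sigmas**-2.0 then serves as the
+ # per-sample loss weight passed to `conditional_loss` as `weighting`.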
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = model_input if args.do_edm_style_training else noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = (
+ model_input
+ if args.do_edm_style_training
+ else noise_scheduler.get_velocity(model_input, noise, timesteps)
+ )
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = conditional_loss(
+ model_pred_prior,
+ target_prior,
+ reduction="mean",
+ loss_type=args.loss_type,
+ huber_c=huber_c,
+ weighting=weighting,
+ )
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred,
+ target,
+ reduction="mean",
+ loss_type=args.loss_type,
+ huber_c=huber_c,
+ weighting=weighting,
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+
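+ # Annotation: this is Min-SNR-gamma weighting (arXiv:2303.09556):
+ #   w_t = min(SNR_t, gamma) / SNR_t        for epsilon / sample prediction
+ #   w_t = min(SNR_t, gamma) / SNR_t + 1    for v-prediction
+ # applied to the per-sample (reduction="none") loss before taking the batch mean.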
+ loss = conditional_loss(
+ model_pred, target, reduction="none", loss_type=args.loss_type, huber_c=huber_c, weighting=None
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+ if args.train_text_encoder
+ else unet_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ if not args.train_text_encoder:
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder_2",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ )
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+ )
+ text_encoder_two = unwrap_model(text_encoder_two)
+ text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+ )
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+ if args.output_kohya_format:
+ lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
+ peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+ save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors")
+
+ # Final inference
+ # Load previous pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ use_dora=args.use_dora,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py b/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3bf95305dad758e28b71bf2264122ed241b4372
--- /dev/null
+++ b/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
@@ -0,0 +1,1162 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images: list = None,
+ repo_folder: str = None,
+):
+ img_str = ""
+ if images is not None and len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "\n"
+
+ model_description = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}\n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_description += wandb_info
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=args.pretrained_model_name_or_path,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
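+# Annotation: an illustrative launch of this script with the scheduled pseudo-Huber
+# loss enabled (the model id and output_dir are placeholders):
+#
+#   accelerate launch train_text_to_image.py \
+#     --pretrained_model_name_or_path <sd-base-model> \
+#     --dataset_name lambdalabs/naruto-blip-captions \
+#     --resolution 512 --train_batch_size 4 \
+#     --loss_type huber --huber_schedule snr --huber_c 0.1 \
+#     --output_dir sd-model-finetuned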
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ Returns a list with a single context manager that disables zero.Init, or an empty list if DeepSpeed is not being used.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+ # frozen models from being partitioned during `zero.Init` which gets called during
+ # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+ # across multiple GPUs and only the UNet2DConditionModel will get ZeRO-sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # Freeze vae and text_encoder and set unet to trainable
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.train()
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
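+ # The EMA copy tracks an exponential moving average of the UNet weights; it is updated after
+ # every optimizer step and swapped in for validation and for the final pipeline export.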
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop the weight so that the corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # For mixed precision training we cast the frozen text_encoder and vae weights to half-precision,
+ # as these models are only used for inference; keeping their weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ args.mixed_precision = accelerator.mixed_precision
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ args.mixed_precision = accelerator.mixed_precision
+
+ # Move text_encoder and vae to the GPU and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Function for unwrapping if model was compiled with `torch.compile`.
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
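+ # The offset draws one extra value per (sample, channel), i.e. very low-frequency noise,
+ # which is intended to let the model learn larger overall brightness/color shifts.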
+ if args.input_perturbation:
+ new_noise = noise + args.input_perturbation * torch.randn_like(noise)
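+ # Input perturbation: new_noise is only used to build the noisy UNet input below, while the
+ # loss target is still computed from the unperturbed `noise`.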
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
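+ # A single timestep is shared by the whole batch so that huber_c can stay a scalar. Both
+ # non-constant schedules move huber_c from ~1 at low-noise timesteps down to args.huber_c at
+ # the noisiest timestep: "exponential" decays it geometrically over the schedule, while "snr"
+ # interpolates based on the sigma of the sampled timestep.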
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(latents.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ if args.input_perturbation:
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+ else:
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
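+ # Clamp the per-timestep SNR at snr_gamma, then turn it into a weight on the epsilon- or
+ # v-prediction loss in the branches below.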
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py b/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4b4d69bb8923976c26b04bb29763555b1ff876e
--- /dev/null
+++ b/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
@@ -0,0 +1,1051 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import cast_training_params, compute_snr
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below.\n
+{img_str}
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "text-to-image",
+ "diffusers",
+ "diffusers-training",
+ "lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Freeze the unet parameters before adding adapters
+ for param in unet.parameters():
+ param.requires_grad_(False)
+
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
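+ # LoRA adds trainable low-rank (rank r) update matrices to the listed attention projections;
+ # with lora_alpha equal to r, the update is applied with scaling alpha / r = 1.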
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Add adapter and make sure the trainable params are in float32.
+ unet.add_adapter(unet_lora_config)
+ if args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(unet, dtype=torch.float32)
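+ # The frozen base weights stay in fp16, but the trainable LoRA parameters are kept in fp32 so
+ # that gradients and optimizer states do not lose precision.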
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
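+ # After add_adapter only the LoRA matrices have requires_grad=True, so this filter yields
+ # exactly the adapter parameters that the optimizer and gradient clipping should see.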
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ lora_layers,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
+
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
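+ # One shared timestep per batch keeps huber_c a scalar; the schedules below move it from ~1
+ # at low-noise timesteps down to args.huber_c at the noisiest timestep.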
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(latents.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = lora_layers
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(unwrapped_unet)
+ )
+
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=save_path,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
+
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ with torch.cuda.amp.autocast():
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ # Final inference
+ # Load previous pipeline
+ if args.validation_prompt is not None:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ with torch.cuda.amp.autocast():
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if len(images) != 0:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py b/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab86bf21a765c4fd20354c3e3ca298e13e2ba8b
--- /dev/null
+++ b/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
@@ -0,0 +1,1384 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+ check_min_version,
+ convert_state_dict_to_diffusers,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ train_text_encoder: bool = False,
+ repo_folder: str = None,
+ vae_path: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below.\n
+{img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers",
+ "diffusers-training",
+ "lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--debug_loss",
+ action="store_true",
+ help="debug loss for each image, if filenames are awailable in the dataset",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
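+# SDXL conditions on two text encoders: per-token embeddings are taken from each encoder's penultimate
+# hidden state ([-1][-2] below) and concatenated along the feature dimension, while the pooled embedding
+# comes from the last encoder in the list (the projection model).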
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
+ )
+
+ # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
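+# The "huber" branch implements the pseudo-Huber loss huber_c * (sqrt(residual^2 + huber_c^2) - huber_c),
+# which behaves like 0.5 * residual^2 for residuals much smaller than huber_c and like huber_c * |residual|
+# for large residuals, making it more robust to outliers than plain L2.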
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber" or loss_type == "huber_scheduled":
+ loss = huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ if args.pretrained_vae_model_name_or_path is None:
+ vae.to(accelerator.device, dtype=torch.float32)
+ else:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # now we will add new LoRA weights to the attention layers
+ # Set correct lora layers
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+
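+ # With lora_alpha equal to r, the LoRA scaling factor alpha / r is 1, so the low-rank update is added
+ # at full strength on the attention projections listed in target_modules.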
+ unet.add_adapter(unet_lora_config)
+
+ # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
+ if args.train_text_encoder:
+ # ensure that dtype is float32, even if the rest of the (frozen) model is loaded in fp16
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+ text_encoder_two.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # There are only two options here: either just the unet attention processor layers,
+ # or the unet plus the text encoder attention layers.
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
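+ # LoRA weights are stored in diffusers format; strip the "unet." prefix and convert the keys back to
+ # PEFT naming before loading them into the adapter attached to unet_.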
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ _set_state_dict_into_text_encoder(
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_, text_encoder_two_])
+ cast_training_params(models, dtype=torch.float32)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one, text_encoder_two])
+ cast_training_params(models, dtype=torch.float32)
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ if args.train_text_encoder:
+ params_to_optimize = (
+ params_to_optimize
+ + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ tokens_one = tokenize_prompt(tokenizer_one, captions)
+ tokens_two = tokenize_prompt(tokenizer_two, captions)
+ return tokens_one, tokens_two
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
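+ # train_flip uses p=1.0 because the flip decision is made explicitly (random.random() < 0.5) in
+ # preprocess_train below, so the transform itself must be deterministic when applied.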
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
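+ # ToTensor scales pixels to [0, 1]; Normalize([0.5], [0.5]) then maps them to [-1, 1], the range the
+ # VAE encoder expects.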
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ tokens_one, tokens_two = tokenize_captions(examples)
+ examples["input_ids_one"] = tokens_one
+ examples["input_ids_two"] = tokens_two
+ if args.debug_loss:
+ fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
+ if fnames:
+ examples["filenames"] = fnames
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
+ input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
+ result = {
+ "pixel_values": pixel_values,
+ "input_ids_one": input_ids_one,
+ "input_ids_two": input_ids_two,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ filenames = [example["filenames"] for example in examples if "filenames" in example]
+ if filenames:
+ result["filenames"] = filenames
+ return result
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+ else:
+ pixel_values = batch["pixel_values"]
+
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
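+ # The offset noise above is constant per channel across spatial positions, which helps the model
+ # learn very dark and very bright images (see the linked blog post).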
+
+ bsz = model_input.shape[0]
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
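+ # For the scheduled Huber losses a single timestep is drawn for the whole batch so that huber_c
+ # can be computed as a plain Python scalar from that timestep.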
+
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # time ids
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
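+ # SDXL's micro-conditioning: the original image size, the crop top-left coordinates, and the target
+ # resolution are concatenated into add_time_ids and passed to the unet via added_cond_kwargs.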
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
+ )
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+ if args.debug_loss and "filenames" in batch:
+ for fname in batch["filenames"]:
+ accelerator.log({"loss_for_" + fname: loss}, step=global_step)
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=unwrap_model(text_encoder_one),
+ text_encoder_2=unwrap_model(text_encoder_two),
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_two = unwrap_model(text_encoder_two)
+
+ text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
+ text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+
+ del unet
+ del text_encoder_one
+ del text_encoder_two
+ del text_encoder_lora_layers
+ del text_encoder_2_lora_layers
+ torch.cuda.empty_cache()
+
+ # Final inference
+ # Make sure vae.dtype is consistent with the unet.dtype
+ if args.mixed_precision == "fp16":
+ vae.to(weight_dtype)
+ # Load previous pipeline
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ train_text_encoder=args.train_text_encoder,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py b/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a056bcfc8cb10e33b409377792111065d53de14e
--- /dev/null
+++ b/diffusers/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
@@ -0,0 +1,1394 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image."""
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import concatenate_datasets, load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ validation_prompt: str = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder: str = None,
+ vae_path: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
+{img_str}
+
+Special VAE used for training: {vae_path}.
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sdxl-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--timestep_bias_strategy",
+ type=str,
+ default="none",
+ choices=["earlier", "later", "range", "none"],
+ help=(
+ "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
+ " Choices: ['earlier', 'later', 'range', 'none']."
+ " The default is 'none', which means no bias is applied, and training proceeds normally."
+ " The value of 'later' will increase the frequency of the model's final training timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_multiplier",
+ type=float,
+ default=1.0,
+ help=(
+ "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
+ " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_begin",
+ type=int,
+ default=0,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
+ " Defaults to zero, which equates to having no specific bias."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_end",
+ type=int,
+ default=1000,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
+ " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_portion",
+ type=float,
+ default=0.25,
+ help=(
+ "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased."
+ " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
+ " whether the biased portions are in the earlier or later timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
+ prompt_embeds_list = []
+ prompt_batch = batch[caption_column]
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ return_dict=False,
+ )
+
+ # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
+
+
+def compute_vae_encodings(batch, vae):
+ images = batch.pop("pixel_values")
+ pixel_values = torch.stack(list(images))
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
+
+ with torch.no_grad():
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ return {"model_input": model_input.cpu()}
+
+
+def generate_timestep_weights(args, num_timesteps):
+ weights = torch.ones(num_timesteps)
+
+ # Determine the indices to bias
+ num_to_bias = int(args.timestep_bias_portion * num_timesteps)
+
+ if args.timestep_bias_strategy == "later":
+ bias_indices = slice(-num_to_bias, None)
+ elif args.timestep_bias_strategy == "earlier":
+ bias_indices = slice(0, num_to_bias)
+ elif args.timestep_bias_strategy == "range":
+ # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.
+ range_begin = args.timestep_bias_begin
+ range_end = args.timestep_bias_end
+ if range_begin < 0:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
+ )
+ if range_end > num_timesteps:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
+ )
+ bias_indices = slice(range_begin, range_end)
+ else: # 'none' or any other string
+ return weights
+ if args.timestep_bias_multiplier <= 0:
+ raise ValueError(
+ "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
+ " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
+ " A timestep bias multiplier less than or equal to 0 is not allowed."
+ )
+
+ # Apply the bias
+ weights[bias_indices] *= args.timestep_bias_multiplier
+
+ # Normalize
+ weights /= weights.sum()
+
+ return weights
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
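+ # Pseudo-Huber loss: behaves like the squared error for residuals much smaller
+ # than huber_c and like 2 * huber_c * |residual| for much larger residuals.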
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ # Check for terminal SNR in combination with SNR Gamma
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # Freeze vae and text encoders.
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ # Set unet as trainable.
+ unet.train()
+
+ # For mixed precision training we cast all non-trainable weights to half-precision,
+ # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = unet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory. We will pre-compute the VAE encodings too.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ compute_embeddings_fn = functools.partial(
+ encode_prompt,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ caption_column=args.caption_column,
+ )
+ compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ new_fingerprint_for_vae = Hasher.hash(vae_path)
+ train_dataset_with_embeddings = train_dataset.map(
+ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+ )
+ train_dataset_with_vae = train_dataset.map(
+ compute_vae_encodings_fn,
+ batched=True,
+ batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
+ new_fingerprint=new_fingerprint_for_vae,
+ )
+ precomputed_dataset = concatenate_datasets(
+ [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
+ )
+ precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
+
+ del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
+ del text_encoders, tokenizers, vae
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def collate_fn(examples):
+ model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+ pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
+
+ return {
+ "model_input": model_input,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ precomputed_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
+
+ # Function for unwrapping if torch.compile() was used in accelerate.
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(precomputed_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Sample noise that we'll add to the latents
+ model_input = batch["model_input"].to(accelerator.device)
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+ if args.timestep_bias_strategy == "none":
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ else:
+ if "huber_scheduled" in args.loss_type:
+ raise NotImplementedError(
+ "Randomly weighted timesteps not implemented yet for scheduled huber loss!"
+ )
+ else:
+ huber_c = args.huber_c
+ # Sample a random timestep for each image, potentially biased by the timestep weights.
+ # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
+ weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
+ model_input.device
+ )
+ timesteps = torch.multinomial(weights, bsz, replacement=True).long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # time ids
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ elif noise_scheduler.config.prediction_type == "sample":
+ # We set the target to latents here, but the model_pred will return the noise sample prediction.
+ target = model_input
+ # We will have to subtract the noise residual from the prediction to get the target sample.
+ model_pred = model_pred - noise
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = unet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ # create pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ # Serialize pipeline.
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unet,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline.save_pretrained(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline = pipeline.to(accelerator.device)
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id=repo_id,
+ images=images,
+ validation_prompt=args.validation_prompt,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/sd3_lora_colab/README.md b/diffusers/examples/research_projects/sd3_lora_colab/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b7d7eedfb5dcfaee925b972fcad66b63d9ea9a12
--- /dev/null
+++ b/diffusers/examples/research_projects/sd3_lora_colab/README.md
@@ -0,0 +1,38 @@
+# Running Stable Diffusion 3 DreamBooth LoRA training under 16GB
+
+This is an **EDUCATIONAL** project that provides utilities for DreamBooth LoRA training for [Stable Diffusion 3 (SD3)](https://huggingface.co/papers/2403.03206) under 16GB of GPU VRAM. This means you can successfully try out this project using a [free-tier Colab Notebook](https://colab.research.google.com/github/huggingface/diffusers/blob/main/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb) instance. 🤗
+
+> [!NOTE]
+> SD3 is gated, so you need to make sure you agree to [share your contact info](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) to access the model before using it with Diffusers. Once you have access, you need to log in so your system knows you’re authorized. Use the command below to log in:
+
+```bash
+huggingface-cli login
+```
+
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
+For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above.
+
+## How
+
+We make use of several techniques to make this possible:
+
+* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8-bit T5 (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) to reduce memory requirements to ~10.5GB.
+* In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of:
+ * 8bit Adam for optimization through the `bitsandbytes` library.
+ * Gradient checkpointing and gradient accumulation.
+ * FP16 precision.
+ * Flash attention through `F.scaled_dot_product_attention()`.
+
+Computing the text embeddings is arguably the most memory-intensive part of the pipeline, as SD3 employs three text encoders. If we run them in FP32, it takes about 20GB of VRAM. With FP16, we are down to 12GB.
+
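+For reference, here is a minimal sketch (not part of this project's scripts) of how the serialized embeddings could be loaded back for training. It assumes the column names written by `compute_embeddings.py`; the vectors are stored flattened, so restoring the original `(batch, seq_len, dim)` shapes is left to the caller.
+
+```python
+import pandas as pd
+import torch
+
+# Load the precomputed embeddings produced by compute_embeddings.py.
+df = pd.read_parquet("sample_embeddings.parquet")
+row = df.iloc[0]
+
+# The embeddings were flattened before serialization, so they come back as 1-D
+# vectors and still need to be reshaped before being fed to the model.
+prompt_embeds = torch.tensor(row["prompt_embeds"], dtype=torch.float32)
+pooled_prompt_embeds = torch.tensor(row["pooled_prompt_embeds"], dtype=torch.float32)
+print(prompt_embeds.shape, pooled_prompt_embeds.shape)
+```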
+
+## Gotchas
+
+This project is educational. It exists to showcase the possibility of fine-tuning a big diffusion system on consumer GPUs. But additional components might have to be added to obtain state-of-the-art performance. Below are some commonly known gotchas that users should be aware of:
+
+* Training of text encoders is purposefully disabled.
+* Techniques such as prior preservation are unsupported.
+* Custom captions for instance images are unsupported, but this should be relatively easy to integrate.
+
+Hopefully, this project gives you a template that you can extend further to suit your needs.
diff --git a/diffusers/examples/research_projects/sd3_lora_colab/compute_embeddings.py b/diffusers/examples/research_projects/sd3_lora_colab/compute_embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..5014752ffe349863d4d90fd95445ebcebd730c24
--- /dev/null
+++ b/diffusers/examples/research_projects/sd3_lora_colab/compute_embeddings.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import glob
+import hashlib
+
+import pandas as pd
+import torch
+from transformers import T5EncoderModel
+
+from diffusers import StableDiffusion3Pipeline
+
+
+PROMPT = "a photo of sks dog"
+MAX_SEQ_LENGTH = 77
+LOCAL_DATA_DIR = "dog"
+OUTPUT_PATH = "sample_embeddings.parquet"
+
+
+def bytes_to_giga_bytes(bytes):
+ return bytes / 1024 / 1024 / 1024
+
+
+def generate_image_hash(image_path):
+ with open(image_path, "rb") as f:
+ img_data = f.read()
+ return hashlib.sha256(img_data).hexdigest()
+
+
+def load_sd3_pipeline():
+ id = "stabilityai/stable-diffusion-3-medium-diffusers"
+ text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_3", load_in_8bit=True, device_map="auto")
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ id, text_encoder_3=text_encoder, transformer=None, vae=None, device_map="balanced"
+ )
+ return pipeline
+
+
+@torch.no_grad()
+def compute_embeddings(pipeline, prompt, max_sequence_length):
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length)
+
+ print(
+ f"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, {pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape}"
+ )
+
+ max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
+ print(f"Max memory allocated: {max_memory:.3f} GB")
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+
+def run(args):
+ pipeline = load_sd3_pipeline()
+ prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = compute_embeddings(
+ pipeline, args.prompt, args.max_sequence_length
+ )
+
+ # Assumes that the images within `args.local_data_dir` have a JPEG extension. Change
+ # as needed.
+ image_paths = glob.glob(f"{args.local_data_dir}/*.jpeg")
+ data = []
+ for image_path in image_paths:
+ img_hash = generate_image_hash(image_path)
+ data.append(
+ (img_hash, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
+ )
+
+ # Create a DataFrame
+ embedding_cols = [
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "pooled_prompt_embeds",
+ "negative_pooled_prompt_embeds",
+ ]
+ df = pd.DataFrame(
+ data,
+ columns=["image_hash"] + embedding_cols,
+ )
+
+ # Convert embedding lists to arrays (for proper storage in parquet)
+ for col in embedding_cols:
+ df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
+
+ # Save the dataframe to a parquet file
+ df.to_parquet(args.output_path)
+ print(f"Data successfully serialized to {args.output_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--prompt", type=str, default=PROMPT, help="The instance prompt.")
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=MAX_SEQ_LENGTH,
+ help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
+ )
+ parser.add_argument(
+ "--local_data_dir", type=str, default=LOCAL_DATA_DIR, help="Path to the directory containing instance images."
+ )
+ parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
+ args = parser.parse_args()
+
+ run(args)
diff --git a/diffusers/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb b/diffusers/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..25fcb36a47d56214a5ed5cffbe60368d7202a041
--- /dev/null
+++ b/diffusers/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
@@ -0,0 +1,2428 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a6xLZDgOajbd"
+ },
+ "source": [
+ "# Running Stable Diffusion 3 (SD3) DreamBooth LoRA training under 16GB GPU VRAM"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0jPZpMTwafua"
+ },
+ "source": [
+ "## Install Dependencies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "lIYdn1woOS1n",
+ "outputId": "6d4a6332-d1f5-46e2-ad2b-c9e51b9f279a"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q -U git+https://github.com/huggingface/diffusers\n",
+ "!pip install -q -U \\\n",
+ " transformers \\\n",
+ " accelerate \\\n",
+ " wandb \\\n",
+ " bitsandbytes \\\n",
+ " peft"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5qUNciw6aov2"
+ },
+ "source": [
+ "As SD3 is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Bpk5FleeK1NR",
+ "outputId": "54d8e774-514e-46fe-b9a7-0185e0bcf211"
+ },
+ "outputs": [],
+ "source": [
+ "!huggingface-cli login"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tcF7gl4FasJV"
+ },
+ "source": [
+ "## Clone `diffusers`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "QgSOJYglJKiM",
+ "outputId": "be51f30f-8848-4a79-ae91-c4fb89c244ba"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone https://github.com/huggingface/diffusers\n",
+ "%cd diffusers/examples/research_projects/sd3_lora_colab"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "X9dBawr6ayRY"
+ },
+ "source": [
+ "## Download instance data images"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 351,
+ "referenced_widgets": [
+ "8720a1f0a3b043dba02b6aab0afb861a",
+ "0e70a30146ef4b30b014179bd4bfd131",
+ "c39072b8cfff4a11ba283a9ae3155e52",
+ "1e834badd9c74b95bda30456a585fc06",
+ "d7c7c83b341b4471ad8a0ca1fe76d9ff",
+ "5ab639bd765f4824818a53ab84f690a8",
+ "cd94205b05d54e4c96c9c475e13abe83",
+ "be260274fdb04798af6fce6169646ff2",
+ "b9912757b9f9477186c171ecb2551d3a",
+ "a1f88f8e27894cdfab54cad04871f74e",
+ "19026c269dce47d585506f734fa2981a",
+ "50237341e55e4da0ba5cdbf652f30115",
+ "1d006f25b17e4cd8aaa5f66d58940dc7",
+ "4c673aa247ff4d65b82c4c64ca2e72da",
+ "92698388b667476ea1daf5cacb2fdb07",
+ "f6b3aa0f980e450289ee15cea7cb3ed7",
+ "0690a95eb8c3403e90d5b023aaadb22c",
+ "4a20ceca22724ab082d013c20c758d31",
+ "c80f3825981646a8a3e178192e338962",
+ "5673d0ca1f1247dd924874355eadecd4",
+ "7774ac850ab2451ea380bf80f3be5a86",
+ "22e57b8c83fa48489d6a327f1bbb756b",
+ "dd2debcf1c774181bef97efab0f3d6e1",
+ "633d7df9a17e4bf6951249aca83a9e96",
+ "6469e9991d7b41a0b83a7b443b9eebe5",
+ "0b9c72fa39c241ba9dd22dd67c2436fe",
+ "99e707cfe1454757aad4014230f6dae8",
+ "5a4ec2d031fa438eb4d0492329b28f00",
+ "6c0d4d6d84704f88b46a9b5cf94e3836",
+ "e1fb8dec23c04d6f8d1217242f8a495c",
+ "4b35f9d8d6444d0397a8bafcf3d73e8f",
+ "0f3279a4e6a247a7b69ff73bc06acfe0",
+ "b5ac4ab9256e4d5092ba6e449bc3cdd3",
+ "2840e90c518d4666b3f5a935c90569a7",
+ "adb012e95d7d442a820680e61e615e3c",
+ "be4fd10d940d49cf8e916904da8192ab",
+ "fd93adba791f46c1b0a25ff692426149",
+ "cdee3b61ca6a487c8ec8e7e884eb8b07",
+ "190a7fbe2b554104a6d5b2caa3b0a08e",
+ "975922b877e143edb09cdb888cb7cae8",
+ "d7365b62df59406dbd38677299cce1c8",
+ "67f0f5f1179140b4bdaa74c5583e3958",
+ "e560f25c3e334cf2a4c748981ac38da6",
+ "65173381a80b40748b7b2800fdb89151",
+ "7a4c5c0acd2d400e91da611e91ff5306",
+ "2a02c69a19a741b4a032dedbc21ad088",
+ "c6211ddb71e64f9e92c70158da2f7ef1",
+ "c219a1e791894469aa1452045b0f74b5",
+ "8e881fd3a17e4a5d95e71f6411ed8167",
+ "5350f001bf774b5fb7f3e362f912bec3",
+ "a893c93bcbc444a4931d1ddc6e342786",
+ "03047a13f06744fcac17c77cb03bca62",
+ "4d77a9c44d1c47b18022f8a51b29e20d",
+ "0b5bb94394fc447282d2c44780303f15",
+ "01bfa49325a8403b808ad1662465996e",
+ "3c0f67144f974aea85c7088f482e830d",
+ "0996e9f698dc4d6ab3c4319c96186619",
+ "9c663933ece544b193531725f4dc873d",
+ "b48dc06ca1654fe39bf8a77352a921f2",
+ "641f5a361a584cc0b71fd71d5f786958",
+ "66a952451b4c43cab54312fe886df5e6",
+ "42757924240c4abeb39add0c26687ab3",
+ "7f26ae5417cf4c80921cce830e60f72b",
+ "b77093ec9ffd40e0b2c1d9d1bdc063f5",
+ "7d8e3510c1e34524993849b8fce52758",
+ "b15011804460483ab84904a52da754b7"
+ ]
+ },
+ "id": "La1rBYWFNjEP",
+ "outputId": "e8567843-193e-4653-86b8-be26390700df"
+ },
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import snapshot_download\n",
+ "\n",
+ "local_dir = \"./dog\"\n",
+ "snapshot_download(\n",
+ " \"diffusers/dog-example\",\n",
+ " local_dir=local_dir, repo_type=\"dataset\",\n",
+ " ignore_patterns=\".gitattributes\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hbsIzdjbOzgi"
+ },
+ "outputs": [],
+ "source": [
+ "!rm -rf dog/.huggingface"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "88sOTn2ga07q"
+ },
+ "source": [
+ "## Compute embeddings\n",
+ "\n",
+ "Here we are using the default instance prompt \"a photo of sks dog\". But you can configure this. Refer to the `compute_embeddings.py` script for details on other supported arguments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ha6hPLpHLM8c",
+ "outputId": "82843eb0-473e-4d6b-d11d-1f79bcd1d11a"
+ },
+ "outputs": [],
+ "source": [
+ "!python compute_embeddings.py"
+ ]
+ },
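+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The cell below is optional and only illustrates the flags exposed by `compute_embeddings.py`; the values shown mirror what this notebook uses elsewhere."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional: recompute the embeddings with explicit arguments instead of the defaults.\n",
+ "# Flag names come from the script's argparse setup; the values are illustrative.\n",
+ "!python compute_embeddings.py \\\n",
+ "  --prompt \"a photo of sks dog\" \\\n",
+ "  --local_data_dir dog \\\n",
+ "  --output_path sample_embeddings.parquet"
+ ]
+ },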
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "10iMo-RUa_yv"
+ },
+ "source": [
+ "## Clear memory"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-YltRmPgMuNa"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import gc\n",
+ "\n",
+ "\n",
+ "def flush():\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ "\n",
+ "flush()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UO5oEtOJbBS9"
+ },
+ "source": [
+ "## Train!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HuJ6hdm2M4Aw",
+ "outputId": "0b2d8ca3-c65f-4bb4-af9a-809b77116510"
+ },
+ "outputs": [],
+ "source": [
+ "!accelerate launch train_dreambooth_lora_sd3_miniature.py \\\n",
+ " --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-3-medium-diffusers\" \\\n",
+ " --instance_data_dir=\"dog\" \\\n",
+ " --data_df_path=\"sample_embeddings.parquet\" \\\n",
+ " --output_dir=\"trained-sd3-lora-miniature\" \\\n",
+ " --mixed_precision=\"fp16\" \\\n",
+ " --instance_prompt=\"a photo of sks dog\" \\\n",
+ " --resolution=1024 \\\n",
+ " --train_batch_size=1 \\\n",
+ " --gradient_accumulation_steps=4 --gradient_checkpointing \\\n",
+ " --use_8bit_adam \\\n",
+ " --learning_rate=1e-4 \\\n",
+ " --report_to=\"wandb\" \\\n",
+ " --lr_scheduler=\"constant\" \\\n",
+ " --lr_warmup_steps=0 \\\n",
+ " --max_train_steps=500 \\\n",
+ " --seed=\"0\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "itS-dsJ0gjy3"
+ },
+ "source": [
+ "Training will take about an hour to complete depending on the length of your dataset."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BpOuL7S1bI6j"
+ },
+ "source": [
+ "## Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "clfMv4jKfQzb"
+ },
+ "outputs": [],
+ "source": [
+ "flush()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "np03SXHkbKpG"
+ },
+ "outputs": [],
+ "source": [
+ "from diffusers import DiffusionPipeline\n",
+ "import torch\n",
+ "\n",
+ "pipeline = DiffusionPipeline.from_pretrained(\n",
+ " \"stabilityai/stable-diffusion-3-medium-diffusers\",\n",
+ " torch_dtype=torch.float16\n",
+ ")\n",
+ "lora_output_path = \"trained-sd3-lora-miniature\"\n",
+ "pipeline.load_lora_weights(\"trained-sd3-lora-miniature\")\n",
+ "\n",
+ "pipeline.enable_sequential_cpu_offload()\n",
+ "\n",
+ "image = pipeline(\"a photo of sks dog in a bucket\").images[0]\n",
+ "image.save(\"bucket_dog.png\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HDfrY2opjGjD"
+ },
+ "source": [
+ "Note that inference will be very slow in this case because we're loading and unloading individual components of the models and that introduces significant data movement overhead. Refer to [this resource](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more memory optimization related techniques."
+ ]
+ },
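+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "One such option, sketched below but not used in this notebook: if your GPU has enough free VRAM, model-level offloading keeps whole models on the GPU while they run and is typically much faster than the sequential offloading enabled above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: model-level offloading trades some memory back for speed.\n",
+ "# Shown commented out because it assumes a fresh pipeline and enough free VRAM/RAM.\n",
+ "#\n",
+ "# pipeline = DiffusionPipeline.from_pretrained(\n",
+ "#     \"stabilityai/stable-diffusion-3-medium-diffusers\", torch_dtype=torch.float16\n",
+ "# )\n",
+ "# pipeline.load_lora_weights(lora_output_path)\n",
+ "# pipeline.enable_model_cpu_offload()\n",
+ "# image = pipeline(\"a photo of sks dog in a bucket\").images[0]"
+ ]
+ }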
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "01bfa49325a8403b808ad1662465996e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "03047a13f06744fcac17c77cb03bca62": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0690a95eb8c3403e90d5b023aaadb22c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0996e9f698dc4d6ab3c4319c96186619": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_66a952451b4c43cab54312fe886df5e6",
+ "placeholder": "",
+ "style": "IPY_MODEL_42757924240c4abeb39add0c26687ab3",
+ "value": "alvan-nee-bQaAJCbNq3g-unsplash.jpeg: 100%"
+ }
+ },
+ "0b5bb94394fc447282d2c44780303f15": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0b9c72fa39c241ba9dd22dd67c2436fe": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0f3279a4e6a247a7b69ff73bc06acfe0",
+ "placeholder": "",
+ "style": "IPY_MODEL_b5ac4ab9256e4d5092ba6e449bc3cdd3",
+ "value": " 1.16M/1.16M [00:00<00:00, 10.3MB/s]"
+ }
+ },
+ "0e70a30146ef4b30b014179bd4bfd131": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5ab639bd765f4824818a53ab84f690a8",
+ "placeholder": "",
+ "style": "IPY_MODEL_cd94205b05d54e4c96c9c475e13abe83",
+ "value": "Fetching 5 files: 100%"
+ }
+ },
+ "0f3279a4e6a247a7b69ff73bc06acfe0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "19026c269dce47d585506f734fa2981a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "190a7fbe2b554104a6d5b2caa3b0a08e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1d006f25b17e4cd8aaa5f66d58940dc7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0690a95eb8c3403e90d5b023aaadb22c",
+ "placeholder": "",
+ "style": "IPY_MODEL_4a20ceca22724ab082d013c20c758d31",
+ "value": "alvan-nee-9M0tSjb-cpA-unsplash.jpeg: 100%"
+ }
+ },
+ "1e834badd9c74b95bda30456a585fc06": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a1f88f8e27894cdfab54cad04871f74e",
+ "placeholder": "",
+ "style": "IPY_MODEL_19026c269dce47d585506f734fa2981a",
+ "value": " 5/5 [00:01<00:00, 3.44it/s]"
+ }
+ },
+ "22e57b8c83fa48489d6a327f1bbb756b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "2840e90c518d4666b3f5a935c90569a7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_adb012e95d7d442a820680e61e615e3c",
+ "IPY_MODEL_be4fd10d940d49cf8e916904da8192ab",
+ "IPY_MODEL_fd93adba791f46c1b0a25ff692426149"
+ ],
+ "layout": "IPY_MODEL_cdee3b61ca6a487c8ec8e7e884eb8b07"
+ }
+ },
+ "2a02c69a19a741b4a032dedbc21ad088": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5350f001bf774b5fb7f3e362f912bec3",
+ "placeholder": "",
+ "style": "IPY_MODEL_a893c93bcbc444a4931d1ddc6e342786",
+ "value": "alvan-nee-eoqnr8ikwFE-unsplash.jpeg: 100%"
+ }
+ },
+ "3c0f67144f974aea85c7088f482e830d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0996e9f698dc4d6ab3c4319c96186619",
+ "IPY_MODEL_9c663933ece544b193531725f4dc873d",
+ "IPY_MODEL_b48dc06ca1654fe39bf8a77352a921f2"
+ ],
+ "layout": "IPY_MODEL_641f5a361a584cc0b71fd71d5f786958"
+ }
+ },
+ "42757924240c4abeb39add0c26687ab3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4a20ceca22724ab082d013c20c758d31": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4b35f9d8d6444d0397a8bafcf3d73e8f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "4c673aa247ff4d65b82c4c64ca2e72da": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c80f3825981646a8a3e178192e338962",
+ "max": 677407,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_5673d0ca1f1247dd924874355eadecd4",
+ "value": 677407
+ }
+ },
+ "4d77a9c44d1c47b18022f8a51b29e20d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "50237341e55e4da0ba5cdbf652f30115": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_1d006f25b17e4cd8aaa5f66d58940dc7",
+ "IPY_MODEL_4c673aa247ff4d65b82c4c64ca2e72da",
+ "IPY_MODEL_92698388b667476ea1daf5cacb2fdb07"
+ ],
+ "layout": "IPY_MODEL_f6b3aa0f980e450289ee15cea7cb3ed7"
+ }
+ },
+ "5350f001bf774b5fb7f3e362f912bec3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5673d0ca1f1247dd924874355eadecd4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "5a4ec2d031fa438eb4d0492329b28f00": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5ab639bd765f4824818a53ab84f690a8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "633d7df9a17e4bf6951249aca83a9e96": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5a4ec2d031fa438eb4d0492329b28f00",
+ "placeholder": "",
+ "style": "IPY_MODEL_6c0d4d6d84704f88b46a9b5cf94e3836",
+ "value": "alvan-nee-Id1DBHv4fbg-unsplash.jpeg: 100%"
+ }
+ },
+ "641f5a361a584cc0b71fd71d5f786958": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6469e9991d7b41a0b83a7b443b9eebe5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e1fb8dec23c04d6f8d1217242f8a495c",
+ "max": 1163467,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_4b35f9d8d6444d0397a8bafcf3d73e8f",
+ "value": 1163467
+ }
+ },
+ "65173381a80b40748b7b2800fdb89151": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "66a952451b4c43cab54312fe886df5e6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "67f0f5f1179140b4bdaa74c5583e3958": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "6c0d4d6d84704f88b46a9b5cf94e3836": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "7774ac850ab2451ea380bf80f3be5a86": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7a4c5c0acd2d400e91da611e91ff5306": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_2a02c69a19a741b4a032dedbc21ad088",
+ "IPY_MODEL_c6211ddb71e64f9e92c70158da2f7ef1",
+ "IPY_MODEL_c219a1e791894469aa1452045b0f74b5"
+ ],
+ "layout": "IPY_MODEL_8e881fd3a17e4a5d95e71f6411ed8167"
+ }
+ },
+ "7d8e3510c1e34524993849b8fce52758": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7f26ae5417cf4c80921cce830e60f72b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8720a1f0a3b043dba02b6aab0afb861a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0e70a30146ef4b30b014179bd4bfd131",
+ "IPY_MODEL_c39072b8cfff4a11ba283a9ae3155e52",
+ "IPY_MODEL_1e834badd9c74b95bda30456a585fc06"
+ ],
+ "layout": "IPY_MODEL_d7c7c83b341b4471ad8a0ca1fe76d9ff"
+ }
+ },
+ "8e881fd3a17e4a5d95e71f6411ed8167": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "92698388b667476ea1daf5cacb2fdb07": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7774ac850ab2451ea380bf80f3be5a86",
+ "placeholder": "",
+ "style": "IPY_MODEL_22e57b8c83fa48489d6a327f1bbb756b",
+ "value": " 677k/677k [00:00<00:00, 9.50MB/s]"
+ }
+ },
+ "975922b877e143edb09cdb888cb7cae8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "99e707cfe1454757aad4014230f6dae8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9c663933ece544b193531725f4dc873d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7f26ae5417cf4c80921cce830e60f72b",
+ "max": 1396297,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_b77093ec9ffd40e0b2c1d9d1bdc063f5",
+ "value": 1396297
+ }
+ },
+ "a1f88f8e27894cdfab54cad04871f74e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a893c93bcbc444a4931d1ddc6e342786": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "adb012e95d7d442a820680e61e615e3c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_190a7fbe2b554104a6d5b2caa3b0a08e",
+ "placeholder": "",
+ "style": "IPY_MODEL_975922b877e143edb09cdb888cb7cae8",
+ "value": "alvan-nee-brFsZ7qszSY-unsplash.jpeg: 100%"
+ }
+ },
+ "b15011804460483ab84904a52da754b7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b48dc06ca1654fe39bf8a77352a921f2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7d8e3510c1e34524993849b8fce52758",
+ "placeholder": "",
+ "style": "IPY_MODEL_b15011804460483ab84904a52da754b7",
+ "value": " 1.40M/1.40M [00:00<00:00, 10.5MB/s]"
+ }
+ },
+ "b5ac4ab9256e4d5092ba6e449bc3cdd3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b77093ec9ffd40e0b2c1d9d1bdc063f5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "b9912757b9f9477186c171ecb2551d3a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "be260274fdb04798af6fce6169646ff2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "be4fd10d940d49cf8e916904da8192ab": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d7365b62df59406dbd38677299cce1c8",
+ "max": 1186464,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_67f0f5f1179140b4bdaa74c5583e3958",
+ "value": 1186464
+ }
+ },
+ "c219a1e791894469aa1452045b0f74b5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0b5bb94394fc447282d2c44780303f15",
+ "placeholder": "",
+ "style": "IPY_MODEL_01bfa49325a8403b808ad1662465996e",
+ "value": " 1.17M/1.17M [00:00<00:00, 15.7MB/s]"
+ }
+ },
+ "c39072b8cfff4a11ba283a9ae3155e52": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_be260274fdb04798af6fce6169646ff2",
+ "max": 5,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_b9912757b9f9477186c171ecb2551d3a",
+ "value": 5
+ }
+ },
+ "c6211ddb71e64f9e92c70158da2f7ef1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_03047a13f06744fcac17c77cb03bca62",
+ "max": 1167042,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_4d77a9c44d1c47b18022f8a51b29e20d",
+ "value": 1167042
+ }
+ },
+ "c80f3825981646a8a3e178192e338962": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cd94205b05d54e4c96c9c475e13abe83": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "cdee3b61ca6a487c8ec8e7e884eb8b07": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d7365b62df59406dbd38677299cce1c8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d7c7c83b341b4471ad8a0ca1fe76d9ff": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dd2debcf1c774181bef97efab0f3d6e1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_633d7df9a17e4bf6951249aca83a9e96",
+ "IPY_MODEL_6469e9991d7b41a0b83a7b443b9eebe5",
+ "IPY_MODEL_0b9c72fa39c241ba9dd22dd67c2436fe"
+ ],
+ "layout": "IPY_MODEL_99e707cfe1454757aad4014230f6dae8"
+ }
+ },
+ "e1fb8dec23c04d6f8d1217242f8a495c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e560f25c3e334cf2a4c748981ac38da6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f6b3aa0f980e450289ee15cea7cb3ed7": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fd93adba791f46c1b0a25ff692426149": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e560f25c3e334cf2a4c748981ac38da6",
+ "placeholder": "",
+ "style": "IPY_MODEL_65173381a80b40748b7b2800fdb89151",
+ "value": " 1.19M/1.19M [00:00<00:00, 12.7MB/s]"
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/diffusers/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/diffusers/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
new file mode 100644
index 0000000000000000000000000000000000000000..163ff8f08931ab3eaffa9e23049c8060d3455507
--- /dev/null
+++ b/diffusers/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
@@ -0,0 +1,1147 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import hashlib
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ SD3Transformer2DModel,
+ StableDiffusion3Pipeline,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.30.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# SD3 DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+## Trigger words
+
+You should use {instance_prompt} to trigger the image generation.
+
+## Download model
+
+[Download]({repo_id}/tree/main) them in the Files & versions tab.
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "sd3",
+ "sd3-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline.enable_model_cpu_offload()
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+ autocast_ctx = nullcontext()
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+ parser.add_argument(
+ "--data_df_path",
+ type=str,
+ default=None,
+ help=("Path to the parquet file serialized with compute_embeddings.py."),
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=77,
+ help="Maximum sequence length to use with with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd3-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="logit_normal",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer.",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.instance_data_dir is None:
+ raise ValueError("Specify `instance_data_dir`.")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ data_df_path,
+ instance_data_root,
+ instance_prompt,
+ size=1024,
+ center_crop=False,
+ ):
+ # Logistics
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ # Load images.
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ image_hashes = [self.generate_image_hash(path) for path in list(Path(instance_data_root).iterdir())]
+ self.instance_images = instance_images
+ self.image_hashes = image_hashes
+
+ # Image transformations
+ self.pixel_values = self.apply_image_transformations(
+ instance_images=instance_images, size=size, center_crop=center_crop
+ )
+
+ # Map hashes to embeddings.
+ self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)
+
+ self.num_instance_images = len(instance_images)
+ self._length = self.num_instance_images
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ image_hash = self.image_hashes[index % self.num_instance_images]
+ prompt_embeds, pooled_prompt_embeds = self.data_dict[image_hash]
+ example["instance_images"] = instance_image
+ example["prompt_embeds"] = prompt_embeds
+ example["pooled_prompt_embeds"] = pooled_prompt_embeds
+ return example
+
+ def apply_image_transformations(self, instance_images, size, center_crop):
+ pixel_values = []
+
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
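+ # Note: `args.random_flip`, `args.center_crop` and `args.resolution` below refer to the module-level
+ # `args` parsed in `__main__`, not to this method's `size`/`center_crop` parameters.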
+ for image in instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ pixel_values.append(image)
+
+ return pixel_values
+
+ def convert_to_torch_tensor(self, embeddings: list):
+ prompt_embeds = embeddings[0]
+ pooled_prompt_embeds = embeddings[1]
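+ # The parquet file stores the embeddings as flat lists; reshape them back into
+ # (sequence_length, hidden_dim) for the prompt embeddings and a single vector for the pooled embeddings.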
+ prompt_embeds = np.array(prompt_embeds).reshape(154, 4096)
+ pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(2048)
+ return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds)
+
+ def map_image_hash_embedding(self, data_df_path):
+ hashes_df = pd.read_parquet(data_df_path)
+ data_dict = {}
+ for i, row in hashes_df.iterrows():
+ embeddings = [row["prompt_embeds"], row["pooled_prompt_embeds"]]
+ prompt_embeds, pooled_prompt_embeds = self.convert_to_torch_tensor(embeddings=embeddings)
+ data_dict.update({row["image_hash"]: (prompt_embeds, pooled_prompt_embeds)})
+ return data_dict
+
+ def generate_image_hash(self, image_path):
+ with open(image_path, "rb") as f:
+ img_data = f.read()
+ return hashlib.sha256(img_data).hexdigest()
+
+
+def collate_fn(examples):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompt_embeds = [example["prompt_embeds"] for example in examples]
+ pooled_prompt_embeds = [example["pooled_prompt_embeds"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ prompt_embeds = torch.stack(prompt_embeds)
+ pooled_prompt_embeds = torch.stack(pooled_prompt_embeds)
+
+ batch = {
+ "pixel_values": pixel_values,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ }
+ return batch
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = SD3Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=torch.float32)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ # now we will add new LoRA weights to the attention layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusion3Pipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ # Optimization parameters
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ data_df_path=args.data_df_path,
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-sd3-lora-miniature"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
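+ # Look up the flow-matching sigma for each sampled timestep and add trailing singleton dims
+ # so it broadcasts over the latent tensor.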
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ with accelerator.accumulate(models_to_accumulate):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
+
+ # Predict the noise residual
+ prompt_embeds, pooled_prompt_embeds = batch["prompt_embeds"], batch["pooled_prompt_embeds"]
+ prompt_embeds = prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ timestep=timesteps,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ return_dict=False,
+ )[0]
+
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ # Preconditioning of the model outputs.
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = model_input
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = transformer_lora_parameters
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ )
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ transformer = transformer.to(torch.float32)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ StableDiffusion3Pipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/sdxl_flax/README.md b/diffusers/examples/research_projects/sdxl_flax/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dfbe90e63bde910eb40535e655c9b9845a89da4c
--- /dev/null
+++ b/diffusers/examples/research_projects/sdxl_flax/README.md
@@ -0,0 +1,243 @@
+# Stable Diffusion XL for JAX + TPUv5e
+
+[TPU v5e](https://cloud.google.com/blog/products/compute/how-cloud-tpu-v5e-accelerates-large-scale-ai-inference) is a new generation of TPUs from Google Cloud. It is the most cost-effective, versatile, and scalable Cloud TPU to date. This makes it ideal for serving and scaling large diffusion models.
+
+[JAX](https://github.com/google/jax) is a high-performance numerical computation library that is well-suited to develop and deploy diffusion models:
+
+- **High performance**. All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) - the Accelerated Linear Algebra compiler
+
+- **Compilation**. JAX uses just-in-time (jit) compilation of JAX Python functions so they can be executed efficiently in XLA. To get the best performance, we must use static shapes for jitted functions; this is because JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type. When a new shape is introduced to an already compiled function, it retriggers compilation for the new shape, which can greatly reduce performance. **Note**: JIT compilation is particularly well-suited for text-to-image generation because all inputs and outputs (image input / output sizes) are static.
+
+- **Parallelization**. Workloads can be scaled across multiple devices using JAX's [pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html), which expresses single-program multiple-data (SPMD) programs. Applying pmap to a function compiles it with XLA and then executes it in parallel across the available XLA devices. For text-to-image generation workloads this means that increasing the number of images rendered simultaneously is straightforward to implement and doesn't compromise performance (see the short sketch after this list).
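+
+As a minimal, self-contained sketch of the SPMD idea (a toy example, not part of the SDXL pipeline), `pmap` runs one copy of a function on each device over the leading axis of its input:
+
+```python
+import jax
+import jax.numpy as jnp
+
+
+def scale(x):
+    return x * 2.0
+
+
+n = jax.device_count()
+batch = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)  # leading axis == number of devices
+out = jax.pmap(scale)(batch)  # each device processes its own slice of the batch
+print(out.shape)  # (n, 4)
+```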
+
+👉 Try it out for yourself:
+
+[](https://huggingface.co/spaces/google/sdxl)
+
+## Stable Diffusion XL pipeline in JAX
+
+Upon having access to a TPU VM (TPUs higher than version 3), you should first install
+a TPU-compatible version of JAX:
+```sh
+pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+Next, we can install [flax](https://github.com/google/flax) and the diffusers library:
+
+```sh
+pip install flax diffusers transformers
+```
+
+In [sdxl_single.py](./sdxl_single.py) we give a simple example of how to write a text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
+
+Let's explain it step-by-step:
+
+**Imports and Setup**
+
+```python
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+from diffusers import FlaxStableDiffusionXLPipeline
+
+from jax.experimental.compilation_cache import compilation_cache as cc
+cc.initialize_cache("/tmp/sdxl_cache")
+import time
+
+NUM_DEVICES = jax.device_count()
+```
+
+First, we import the necessary libraries:
+- `jax` provides the primitives for TPU operations
+- `flax.jax_utils` contains some useful utility functions for `Flax`, a neural network library built on top of JAX
+- `diffusers` has all the code that is relevant for SDXL.
+- We also initialize a cache to speed up the JAX model compilation.
+- We automatically determine the number of available TPU devices.
+
+**1. Downloading Model and Loading Pipeline**
+
+```python
+pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
+)
+```
+Here, a pre-trained model `stable-diffusion-xl-base-1.0` from the namespace `stabilityai` is loaded. It returns a pipeline for inference and its parameters.
+
+**2. Casting Parameter Types**
+
+```python
+scheduler_state = params.pop("scheduler")
+params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+params["scheduler"] = scheduler_state
+```
+This section adjusts the data types of the model parameters.
+We convert all parameters to `bfloat16` to speed up computation with the model weights.
+**Note** that the scheduler parameters are **not** converted to `bfloat16`, as the loss of
+precision would degrade the pipeline's performance too significantly.
+
+**3. Define Inputs to Pipeline**
+
+```python
+default_prompt = ...
+default_neg_prompt = ...
+default_seed = 33
+default_guidance_scale = 5.0
+default_num_steps = 25
+```
+Here, various default inputs for the pipeline are set, including the prompt, negative prompt, random seed, guidance scale, and the number of inference steps.
+
+**4. Tokenizing Inputs**
+
+```python
+def tokenize_prompt(prompt, neg_prompt):
+ prompt_ids = pipeline.prepare_inputs(prompt)
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+ return prompt_ids, neg_prompt_ids
+```
+This function tokenizes the given prompts. It's essential because the text encoders of SDXL don't understand raw text; they work with numbers. Tokenization converts text to numbers.
+
+**5. Parallelization and Replication**
+
+```python
+p_params = replicate(params)
+
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+ ...
+```
+To utilize JAX's parallel capabilities, the parameters and input tensors are duplicated across devices. The `replicate_all` function also ensures that every device produces a different image by creating a unique random seed for each device.
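+
+For reference, [sdxl_single.py](./sdxl_single.py) implements `replicate_all` roughly as follows: the prompt ids are replicated as-is, while the seed is split into one PRNG key per device so that each device renders a different image.
+
+```python
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+    p_prompt_ids = replicate(prompt_ids)
+    p_neg_prompt_ids = replicate(neg_prompt_ids)
+    rng = jax.random.PRNGKey(seed)
+    rng = jax.random.split(rng, NUM_DEVICES)  # one key per device -> different images
+    return p_prompt_ids, p_neg_prompt_ids, rng
+```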
+
+**6. Putting Everything Together**
+
+```python
+def generate(...):
+ ...
+```
+This function integrates all the steps to produce the desired outputs from the model. It takes in prompts, tokenizes them, replicates them across devices, runs them through the pipeline, and converts the images to a format that's more interpretable (PIL format).
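+
+Concretely, the `generate` function in [sdxl_single.py](./sdxl_single.py) looks roughly like this:
+
+```python
+def generate(
+    prompt,
+    negative_prompt,
+    seed=default_seed,
+    guidance_scale=default_guidance_scale,
+    num_inference_steps=default_num_steps,
+):
+    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+    images = pipeline(
+        prompt_ids,
+        p_params,
+        rng,
+        num_inference_steps=num_inference_steps,
+        neg_prompt_ids=neg_prompt_ids,
+        guidance_scale=guidance_scale,
+        jit=True,
+    ).images
+
+    # merge the device axis and convert the images to PIL
+    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+    return pipeline.numpy_to_pil(np.array(images))
+```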
+
+**7. Compilation Step**
+
+```python
+start = time.time()
+print(f"Compiling ...")
+generate(default_prompt, default_neg_prompt)
+print(f"Compiled in {time.time() - start}")
+```
+The initial run of the `generate` function will be slow because JAX compiles the function during this call. By running it once here, subsequent calls will be much faster. This section measures and prints the compilation time.
+
+**8. Fast Inference**
+
+```python
+start = time.time()
+prompt = ...
+neg_prompt = ...
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+```
+Now that the function is compiled, this section shows how to use it for fast inference. It measures and prints the inference time.
+
+In summary, the code demonstrates how to load a pre-trained model using Flax and JAX, prepare it for inference, and run it efficiently using JAX's capabilities.
+
+## Ahead of Time (AOT) Compilation
+
+FlaxStableDiffusionXLPipeline takes care of parallelization across multiple devices using jit. Now let's build parallelization ourselves.
+
+For this we will be using a JAX feature called [Ahead of Time](https://jax.readthedocs.io/en/latest/aot.html) (AOT) lowering and compilation. AOT lets us fully compile the function prior to execution time and gives us control over the different parts of the compilation process.
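+
+As a minimal sketch of the general AOT workflow (a toy example, independent of the SDXL pipeline), a jitted function can be lowered for concrete input shapes and compiled before it is ever executed:
+
+```python
+import jax
+import jax.numpy as jnp
+
+
+def scale(x, factor):
+    return x * factor
+
+
+# Lower for a concrete shape/dtype with `factor` marked static, then compile ahead of time.
+compiled = jax.jit(scale, static_argnums=1).lower(jnp.ones((4, 4)), 2.0).compile()
+print(compiled(jnp.ones((4, 4))))  # the static argument is baked in and omitted from the call
+```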
+
+In [sdxl_single_aot.py](./sdxl_single_aot.py) we give a simple example of how to write our own parallelization logic for a text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
+
+We add an `aot_compile` function that compiles the `pipeline._generate` function,
+telling JAX which input arguments are static, that is, arguments that
+are known at compile time and won't change. In our case, these are
+`num_inference_steps`, `height`, `width` and `return_latents`.
+
+Once the function is compiled, these parameters are omitted from future calls and
+cannot be changed without modifying the code and recompiling.
+
+```python
+def aot_compile(
+ prompt=default_prompt,
+ negative_prompt=default_neg_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale,
+ num_inference_steps=default_num_steps
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+
+ return pmap(
+ pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9]
+ ).lower(
+ prompt_ids,
+ p_params,
+ rng,
+ num_inference_steps, # num_inference_steps
+ height, # height
+ width, # width
+ g,
+ None,
+ neg_prompt_ids,
+ False # return_latents
+ ).compile()
+```
+
+Next we can compile the generate function by executing `aot_compile`.
+
+```python
+start = time.time()
+print("Compiling ...")
+p_generate = aot_compile()
+print(f"Compiled in {time.time() - start}")
+```
+And again we put everything together in a `generate` function.
+
+```python
+def generate(
+ prompt,
+ negative_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+ images = p_generate(
+ prompt_ids,
+ p_params,
+ rng,
+ g,
+ None,
+ neg_prompt_ids)
+
+ # convert the images to PIL
+ images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
+ return pipeline.numpy_to_pil(np.array(images))
+```
+
+The first forward pass after AOT compilation still takes a while longer than
+subsequent passes. This is because, on the first pass, JAX uses Python dispatch, which
+fills the C++ dispatch cache.
+When using jit, this extra step is done automatically, but when using AOT compilation,
+it doesn't happen until the compiled function is actually called.
+
+```python
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"First inference in {time.time() - start}")
+```
+
+From this point forward, any call to `generate` should run at this faster inference
+time, which won't change from call to call.
+
+```python
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+```
diff --git a/diffusers/examples/research_projects/sdxl_flax/sdxl_single.py b/diffusers/examples/research_projects/sdxl_flax/sdxl_single.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9b862d99b5c2d583d28395ccb2191e50bef7a4
--- /dev/null
+++ b/diffusers/examples/research_projects/sdxl_flax/sdxl_single.py
@@ -0,0 +1,106 @@
+# Show best practices for SDXL JAX
+import time
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+
+# Let's cache the model compilation, so that it doesn't take as long the next time around.
+from jax.experimental.compilation_cache import compilation_cache as cc
+
+from diffusers import FlaxStableDiffusionXLPipeline
+
+
+cc.initialize_cache("/tmp/sdxl_cache")
+
+
+NUM_DEVICES = jax.device_count()
+
+# 1. Let's start by downloading the model and loading it into our pipeline class
+# Adhering to JAX's functional approach, the model's parameters are returned separately and
+# will have to be passed to the pipeline during inference
+pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
+)
+
+# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
+# float32 to keep maximal precision
+scheduler_state = params.pop("scheduler")
+params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+params["scheduler"] = scheduler_state
+
+# 3. Next, we define the different inputs to the pipeline
+default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
+default_neg_prompt = "fog, grainy, purple"
+default_seed = 33
+default_guidance_scale = 5.0
+default_num_steps = 25
+
+
+# 4. In order to be able to compile the pipeline
+# all inputs have to be tensors or strings
+# Let's tokenize the prompt and negative prompt
+def tokenize_prompt(prompt, neg_prompt):
+ prompt_ids = pipeline.prepare_inputs(prompt)
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+ return prompt_ids, neg_prompt_ids
+
+
+# 5. To make full use of JAX's parallelization capabilities
+# the parameters and input tensors are duplicated across devices
+# To make sure every device generates a different image, we create
+# different seeds for each image. The model parameters won't change
+# during inference so we do not wrap them into a function
+p_params = replicate(params)
+
+
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+ p_prompt_ids = replicate(prompt_ids)
+ p_neg_prompt_ids = replicate(neg_prompt_ids)
+ rng = jax.random.PRNGKey(seed)
+ rng = jax.random.split(rng, NUM_DEVICES)
+ return p_prompt_ids, p_neg_prompt_ids, rng
+
+
+# 6. Let's now put it all together in a generate function
+def generate(
+ prompt,
+ negative_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale,
+ num_inference_steps=default_num_steps,
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ images = pipeline(
+ prompt_ids,
+ p_params,
+ rng,
+ num_inference_steps=num_inference_steps,
+ neg_prompt_ids=neg_prompt_ids,
+ guidance_scale=guidance_scale,
+ jit=True,
+ ).images
+
+ # convert the images to PIL
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+ return pipeline.numpy_to_pil(np.array(images))
+
+
+# 7. Remember that the first call will compile the function and hence be very slow. Let's run generate once
+# so that the pipeline call is compiled
+start = time.time()
+print("Compiling ...")
+generate(default_prompt, default_neg_prompt)
+print(f"Compiled in {time.time() - start}")
+
+# 8. Now the model forward pass will run very quickly, let's try it again
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+
+for i, image in enumerate(images):
+ image.save(f"castle_{i}.png")
diff --git a/diffusers/examples/research_projects/sdxl_flax/sdxl_single_aot.py b/diffusers/examples/research_projects/sdxl_flax/sdxl_single_aot.py
new file mode 100644
index 0000000000000000000000000000000000000000..08bd13902aa9b3caa8553de12b6702a07d6a0936
--- /dev/null
+++ b/diffusers/examples/research_projects/sdxl_flax/sdxl_single_aot.py
@@ -0,0 +1,143 @@
+import time
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+from jax import pmap
+
+# Let's cache the model compilation, so that it doesn't take as long the next time around.
+from jax.experimental.compilation_cache import compilation_cache as cc
+
+from diffusers import FlaxStableDiffusionXLPipeline
+
+
+cc.initialize_cache("/tmp/sdxl_cache")
+
+
+NUM_DEVICES = jax.device_count()
+
+# 1. Let's start by downloading the model and loading it into our pipeline class
+# Adhering to JAX's functional approach, the model's parameters are returned separately and
+# will have to be passed to the pipeline during inference
+pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
+)
+
+# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
+# float32 to keep maximal precision
+scheduler_state = params.pop("scheduler")
+params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+params["scheduler"] = scheduler_state
+
+# 3. Next, we define the different inputs to the pipeline
+default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
+default_neg_prompt = "fog, grainy, purple"
+default_seed = 33
+default_guidance_scale = 5.0
+default_num_steps = 25
+width = 1024
+height = 1024
+
+
+# 4. In order to be able to compile the pipeline
+# all inputs have to be tensors or strings
+# Let's tokenize the prompt and negative prompt
+def tokenize_prompt(prompt, neg_prompt):
+ prompt_ids = pipeline.prepare_inputs(prompt)
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+ return prompt_ids, neg_prompt_ids
+
+
+# 5. To make full use of JAX's parallelization capabilities
+# the parameters and input tensors are duplicated across devices
+# To make sure every device generates a different image, we create
+# different seeds for each image. The model parameters won't change
+# during inference so we do not wrap them into a function
+p_params = replicate(params)
+
+
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+ p_prompt_ids = replicate(prompt_ids)
+ p_neg_prompt_ids = replicate(neg_prompt_ids)
+ rng = jax.random.PRNGKey(seed)
+ rng = jax.random.split(rng, NUM_DEVICES)
+ return p_prompt_ids, p_neg_prompt_ids, rng
+
+
+# 6. To compile the pipeline._generate function, we must pass all parameters
+# to the function and tell JAX which are static arguments, that is, arguments that
+# are known at compile time and won't change. In our case, it is num_inference_steps,
+# height, width and return_latents.
+# Once the function is compiled, these parameters are omitted from future calls and
+# cannot be changed without modifying the code and recompiling.
+def aot_compile(
+ prompt=default_prompt,
+ negative_prompt=default_neg_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale,
+ num_inference_steps=default_num_steps,
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+
+ return (
+ pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9])
+ .lower(
+ prompt_ids,
+ p_params,
+ rng,
+ num_inference_steps, # num_inference_steps
+ height, # height
+ width, # width
+ g, # guidance_scale
+ None, # latents
+ neg_prompt_ids,
+ False, # return_latents
+ )
+ .compile()
+ )
+
+
+start = time.time()
+print("Compiling ...")
+p_generate = aot_compile()
+print(f"Compiled in {time.time() - start}")
+
+
+# 7. Let's now put it all together in a generate function.
+def generate(prompt, negative_prompt, seed=default_seed, guidance_scale=default_guidance_scale):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+ images = p_generate(prompt_ids, p_params, rng, g, None, neg_prompt_ids)
+
+ # convert the images to PIL
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+ return pipeline.numpy_to_pil(np.array(images))
+
+
+# 8. The first forward pass after AOT compilation still takes a while longer than
+# subsequent passes because, on the first pass, JAX uses Python dispatch, which
+# fills the C++ dispatch cache.
+# When using jit, this extra step happens automatically, but with AOT compilation
+# it is deferred until the compiled function is actually called.
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"First inference in {time.time() - start}")
+
+# 9. From this point forward, every call to generate should run at the same, much faster
+# inference time.
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+
+for i, image in enumerate(images):
+ image.save(f"castle_{i}.png")
diff --git a/diffusers/examples/research_projects/vae/README.md b/diffusers/examples/research_projects/vae/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5bb8a5ffd2caccee3052e7f2aacc25a540fbf98c
--- /dev/null
+++ b/diffusers/examples/research_projects/vae/README.md
@@ -0,0 +1,11 @@
+# VAE
+
+`vae_roundtrip.py` demonstrates the use of a VAE by round-tripping an image through the encoder and decoder. The original and reconstructed images are displayed side by side.
+
+```bash
+cd examples/research_projects/vae
+python vae_roundtrip.py \
+ --pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5" \
+ --subfolder="vae" \
+ --input_image="/path/to/your/input.png"
+```
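+
+For orientation, the core of what the script does can be sketched in a few lines with the `diffusers` API (a minimal sketch assuming an `AutoencoderKL` checkpoint and placeholder file names, not the full script):
+
+```python
+import torch
+from PIL import Image
+
+from diffusers import AutoencoderKL
+from diffusers.image_processor import VaeImageProcessor
+
+# Load only the VAE from the Stable Diffusion checkpoint.
+vae = AutoencoderKL.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae"
+).eval()
+
+processor = VaeImageProcessor()
+image = processor.preprocess(Image.open("input.png").convert("RGB"))
+
+with torch.no_grad():
+    # Encode to the latent space, then decode back to pixel space.
+    latents = vae.encode(image).latent_dist.sample()
+    reconstruction = vae.decode(latents).sample
+
+processor.postprocess(reconstruction)[0].save("reconstruction.png")
+```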
diff --git a/diffusers/examples/research_projects/vae/vae_roundtrip.py b/diffusers/examples/research_projects/vae/vae_roundtrip.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c2b43a9bde606d2fd78910e837559d3bdc60c8
--- /dev/null
+++ b/diffusers/examples/research_projects/vae/vae_roundtrip.py
@@ -0,0 +1,282 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import typing
+from typing import Optional, Union
+
+import torch
+from PIL import Image
+from torchvision import transforms # type: ignore
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.autoencoders.autoencoder_kl import (
+ AutoencoderKL,
+ AutoencoderKLOutput,
+)
+from diffusers.models.autoencoders.autoencoder_tiny import (
+ AutoencoderTiny,
+ AutoencoderTinyOutput,
+)
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+
+SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]
+
+
+def load_vae_model(
+ *,
+ device: torch.device,
+ model_name_or_path: str,
+ revision: Optional[str],
+ variant: Optional[str],
+ # NOTE: use subfolder="vae" if the pointed model is for stable diffusion as a whole instead of just the VAE
+ subfolder: Optional[str],
+ use_tiny_nn: bool,
+) -> SupportedAutoencoder:
+ if use_tiny_nn:
+ # NOTE: These scaling factors don't have to be the same as each other.
+ down_scale = 2
+ up_scale = 2
+ vae = AutoencoderTiny.from_pretrained( # type: ignore
+ model_name_or_path,
+ subfolder=subfolder,
+ revision=revision,
+ variant=variant,
+ downscaling_scaling_factor=down_scale,
+ upsampling_scaling_factor=up_scale,
+ )
+ assert isinstance(vae, AutoencoderTiny)
+ else:
+ vae = AutoencoderKL.from_pretrained( # type: ignore
+ model_name_or_path,
+ subfolder=subfolder,
+ revision=revision,
+ variant=variant,
+ )
+ assert isinstance(vae, AutoencoderKL)
+ vae = vae.to(device)
+ vae.eval() # Set the model to inference mode
+ return vae
+
+
+def pil_to_nhwc(
+ *,
+ device: torch.device,
+ image: Image.Image,
+) -> torch.Tensor:
+ assert image.mode == "RGB"
+ transform = transforms.ToTensor()
+ nhwc = transform(image).unsqueeze(0).to(device) # type: ignore
+ assert isinstance(nhwc, torch.Tensor)
+ return nhwc
+
+
+def nhwc_to_pil(
+ *,
+ nhwc: torch.Tensor,
+) -> Image.Image:
+ assert nhwc.shape[0] == 1
+ hwc = nhwc.squeeze(0).cpu()
+ return transforms.ToPILImage()(hwc) # type: ignore
+
+
+def concatenate_images(
+ *,
+ left: Image.Image,
+ right: Image.Image,
+ vertical: bool = False,
+) -> Image.Image:
+ width1, height1 = left.size
+ width2, height2 = right.size
+ if vertical:
+ total_height = height1 + height2
+ max_width = max(width1, width2)
+ new_image = Image.new("RGB", (max_width, total_height))
+ new_image.paste(left, (0, 0))
+ new_image.paste(right, (0, height1))
+ else:
+ total_width = width1 + width2
+ max_height = max(height1, height2)
+ new_image = Image.new("RGB", (total_width, max_height))
+ new_image.paste(left, (0, 0))
+ new_image.paste(right, (width1, 0))
+ return new_image
+
+
+def to_latent(
+ *,
+ rgb_nchw: torch.Tensor,
+ vae: SupportedAutoencoder,
+) -> torch.Tensor:
+ rgb_nchw = VaeImageProcessor.normalize(rgb_nchw) # type: ignore
+ encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
+ if isinstance(encoding_nchw, AutoencoderKLOutput):
+ latent = encoding_nchw.latent_dist.sample() # type: ignore
+ assert isinstance(latent, torch.Tensor)
+ elif isinstance(encoding_nchw, AutoencoderTinyOutput):
+ latent = encoding_nchw.latents
+ do_internal_vae_scaling = False # Is this needed?
+ if do_internal_vae_scaling:
+ latent = vae.scale_latents(latent).mul(255).round().byte() # type: ignore
+ latent = vae.unscale_latents(latent / 255.0) # type: ignore
+ assert isinstance(latent, torch.Tensor)
+ else:
+ assert False, f"Unknown encoding type: {type(encoding_nchw)}"
+ return latent
+
+
+def from_latent(
+ *,
+ latent_nchw: torch.Tensor,
+ vae: SupportedAutoencoder,
+) -> torch.Tensor:
+ decoding_nchw = vae.decode(latent_nchw) # type: ignore
+ assert isinstance(decoding_nchw, DecoderOutput)
+ rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample) # type: ignore
+ assert isinstance(rgb_nchw, torch.Tensor)
+ return rgb_nchw
+
+
+def main_kwargs(
+ *,
+ device: torch.device,
+ input_image_path: str,
+ pretrained_model_name_or_path: str,
+ revision: Optional[str],
+ variant: Optional[str],
+ subfolder: Optional[str],
+ use_tiny_nn: bool,
+) -> None:
+ vae = load_vae_model(
+ device=device,
+ model_name_or_path=pretrained_model_name_or_path,
+ revision=revision,
+ variant=variant,
+ subfolder=subfolder,
+ use_tiny_nn=use_tiny_nn,
+ )
+ original_pil = Image.open(input_image_path).convert("RGB")
+ original_image = pil_to_nhwc(
+ device=device,
+ image=original_pil,
+ )
+ print(f"Original image shape: {original_image.shape}")
+ reconstructed_image: Optional[torch.Tensor] = None
+
+ with torch.no_grad():
+ latent_image = to_latent(rgb_nchw=original_image, vae=vae)
+ print(f"Latent shape: {latent_image.shape}")
+ reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)
+ reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image)
+ combined_image = concatenate_images(
+ left=original_pil,
+ right=reconstructed_pil,
+ vertical=False,
+ )
+ combined_image.show("Original | Reconstruction")
+ print(f"Reconstructed image shape: {reconstructed_image.shape}")
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Inference with VAE")
+ parser.add_argument(
+ "--input_image",
+ type=str,
+ required=True,
+ help="Path to the input image for inference.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ required=True,
+ help="Path to pretrained VAE model.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ help="Model version.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Model file variant, e.g., 'fp16'.",
+ )
+ parser.add_argument(
+ "--subfolder",
+ type=str,
+ default=None,
+ help="Subfolder in the model file.",
+ )
+ parser.add_argument(
+ "--use_cuda",
+ action="store_true",
+ help="Use CUDA if available.",
+ )
+ parser.add_argument(
+ "--use_tiny_nn",
+ action="store_true",
+ help="Use tiny neural network.",
+ )
+ return parser.parse_args()
+
+
+# EXAMPLE USAGE:
+#
+# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
+#
+# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
+#
+def main_cli() -> None:
+ args = parse_args()
+
+ input_image_path = args.input_image
+ assert isinstance(input_image_path, str)
+
+ pretrained_model_name_or_path = args.pretrained_model_name_or_path
+ assert isinstance(pretrained_model_name_or_path, str)
+
+ revision = args.revision
+ assert isinstance(revision, (str, type(None)))
+
+ variant = args.variant
+ assert isinstance(variant, (str, type(None)))
+
+ subfolder = args.subfolder
+ assert isinstance(subfolder, (str, type(None)))
+
+ use_cuda = args.use_cuda
+ assert isinstance(use_cuda, bool)
+
+ use_tiny_nn = args.use_tiny_nn
+ assert isinstance(use_tiny_nn, bool)
+
+ device = torch.device("cuda" if use_cuda else "cpu")
+
+ main_kwargs(
+ device=device,
+ input_image_path=input_image_path,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ revision=revision,
+ variant=variant,
+ subfolder=subfolder,
+ use_tiny_nn=use_tiny_nn,
+ )
+
+
+if __name__ == "__main__":
+ main_cli()
diff --git a/diffusers/examples/t2i_adapter/README.md b/diffusers/examples/t2i_adapter/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d7491950d0ec22ed6f9badd8307efee73789786
--- /dev/null
+++ b/diffusers/examples/t2i_adapter/README.md
@@ -0,0 +1 @@
+We don't support training T2I-Adapters on Stable Diffusion yet. For training T2I-Adapters on Stable Diffusion XL, refer [here](./README_sdxl.md).
\ No newline at end of file
diff --git a/diffusers/examples/t2i_adapter/README_sdxl.md b/diffusers/examples/t2i_adapter/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..1e5a19fedad1b1e94de1d8a8ad3202fa7a96c53c
--- /dev/null
+++ b/diffusers/examples/t2i_adapter/README_sdxl.md
@@ -0,0 +1,131 @@
+# T2I-Adapter training example for Stable Diffusion XL (SDXL)
+
+The `train_t2i_adapter_sdxl.py` script shows how to implement the [T2I-Adapter training procedure](https://hf.co/papers/2302.08453) for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/t2i_adapter` folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+
+## Circle filling dataset
+
+The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
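+
+If you want to peek at the data before training, here is a minimal sketch (the column names below are the training script's defaults and are assumed to match the re-uploaded dataset):
+
+```python
+from datasets import load_dataset
+
+# `datasets` downloads and caches the dataset from the Hub automatically.
+dataset = load_dataset("fusing/fill50k", split="train")
+print(dataset.column_names)  # expected: image, conditioning_image, text
+print(dataset[0]["text"])    # e.g. a caption describing the circle colors
+```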
+
+## Training
+
+Our training examples use two test conditioning images. They can be downloaded by running
+
+```sh
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained T2IAdapter parameters to the Hugging Face Hub.
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_t2i_adapter_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --mixed_precision="fp16" \
+ --resolution=1024 \
+ --learning_rate=1e-5 \
+ --max_train_steps=15000 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=100 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --report_to="wandb" \
+ --seed=42 \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+### Inference
+
+Once training is done, we can perform inference like so:
+
+```python
+from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler
+from diffusers.utils import load_image
+import torch
+
+base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+adapter_path = "path to adapter"
+
+adapter = T2IAdapter.from_pretrained(adapter_path, torch_dtype=torch.float16)
+pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+ base_model_path, adapter=adapter, torch_dtype=torch.float16
+)
+
+# speed up diffusion process with faster scheduler and memory optimization
+pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+# remove the following line if xformers is not installed or when using Torch 2.0.
+pipe.enable_xformers_memory_efficient_attention()
+# memory optimization.
+pipe.enable_model_cpu_offload()
+
+control_image = load_image("./conditioning_image_1.png")
+prompt = "pale golden rod circle with old lace background"
+
+# generate image
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt, num_inference_steps=20, generator=generator, image=control_image
+).images[0]
+image.save("./output.png")
+```
+
+## Notes
+
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
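+
+For example, the flag can simply be added to the training command shown above (a sketch; keep the rest of your arguments as before):
+
+```bash
+accelerate launch train_t2i_adapter_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --mixed_precision="fp16" \
+ --resolution=1024 \
+ --train_batch_size=1
+```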
diff --git a/diffusers/examples/t2i_adapter/requirements.txt b/diffusers/examples/t2i_adapter/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2955535b192796e2618393012560a7b534fbea23
--- /dev/null
+++ b/diffusers/examples/t2i_adapter/requirements.txt
@@ -0,0 +1,8 @@
+transformers>=4.25.1
+accelerate>=0.16.0
+safetensors
+datasets
+torchvision
+ftfy
+tensorboard
+wandb
\ No newline at end of file
diff --git a/diffusers/examples/t2i_adapter/test_t2i_adapter.py b/diffusers/examples/t2i_adapter/test_t2i_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf124cdd9321a37540e5de6645d905eba3900d7
--- /dev/null
+++ b/diffusers/examples/t2i_adapter/test_t2i_adapter.py
@@ -0,0 +1,51 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class T2IAdapter(ExamplesTestsAccelerate):
+ def test_t2i_adapter_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/t2i_adapter/train_t2i_adapter_sdxl.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --adapter_model_name_or_path=hf-internal-testing/tiny-adapter
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
diff --git a/diffusers/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/diffusers/examples/t2i_adapter/train_t2i_adapter_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a064fa81ed30ea8f04331fcfbdc285b885c0a2
--- /dev/null
+++ b/diffusers/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -0,0 +1,1327 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ EulerDiscreteScheduler,
+ StableDiffusionXLAdapterPipeline,
+ T2IAdapter,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def image_grid(imgs, rows, cols):
+ assert len(imgs) == rows * cols
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ adapter = accelerator.unwrap_model(adapter)
+
+ pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=unet,
+ adapter=adapter,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images = []
+
+ formatted_images.append(np.asarray(validation_image))
+
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="adapter conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs: dict = None, base_model: str = None, repo_folder: str = None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# t2iadapter-{repo_id}
+
+These are t2iadapter weights trained on {base_model} with a new type of conditioning.
+{img_str}
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers",
+ "t2iadapter",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a T2I-Adapter training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--adapter_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained adapter model or model identifier from huggingface.co/models."
+ " If not specified adapter weights are initialized w.r.t the configurations of SDXL.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="t2iadapter-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--detection_resolution",
+ type=int,
+ default=None,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=3,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=1,
+ help=("Number of subprocesses to use for data loading."),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the adapter conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the t2iadapter conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="sd_xl_train_t2iadapter",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the t2iadapter encoder."
+ )
+
+ return args
+
+
+def get_train_dataset(args, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ if args.train_data_dir is not None:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ with accelerator.main_process_first():
+ train_dataset = dataset["train"].shuffle(seed=args.seed)
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(args.max_train_samples))
+ return train_dataset
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # Only the pooled output of the final text encoder is kept (it is overwritten on each loop iteration)
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def prepare_train_dataset(dataset, accelerator):
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[args.image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+
+ return examples
+
+ with accelerator.main_process_first():
+ dataset = dataset.with_transform(preprocess_train)
+
+ return dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+
+ add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples])
+ add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "prompt_ids": prompt_ids,
+ "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
+ }
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ if args.adapter_model_name_or_path:
+ logger.info("Loading existing adapter weights.")
+ t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path)
+ else:
+ logger.info("Initializing t2iadapter weights.")
+ t2iadapter = T2IAdapter(
+ in_channels=3,
+ channels=(320, 640, 1280, 1280),
+ num_res_blocks=2,
+ downscale_factor=16,
+ adapter_type="full_adapter_xl",
+ )
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "t2iadapter"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = T2IAdapter.from_pretrained(os.path.join(input_dir, "t2iadapter"))
+
+ if args.control_type != "style":
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ t2iadapter.train()
+ unet.train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if unwrap_model(t2iadapter).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = t2iadapter.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ else:
+ vae.to(accelerator.device, dtype=torch.float32)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):
+ original_size = (args.resolution, args.resolution)
+ target_size = (args.resolution, args.resolution)
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ prompt_batch = batch[args.caption_column]
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
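+ # SDXL "micro-conditioning": the UNet additionally receives the original image size,
+ # the crop top-left coordinates and the target size as extra time-embedding inputs.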
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
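+ # Look up the noise level (sigma) of each sampled timestep from the Euler scheduler
+ # and reshape it so it broadcasts against the latent tensor.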
+ sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ train_dataset = get_train_dataset(args, accelerator)
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+
+ # Then get the training dataset ready to be passed to the dataloader.
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ t2iadapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ t2iadapter, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(t2iadapter):
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+ else:
+ pixel_values = batch["pixel_values"]
+
+ # encode pixel values with batch size of at most 8 to avoid OOM
+ latents = []
+ for i in range(0, pixel_values.shape[0], 8):
+ latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Cubic sampling to sample a random timestep for each image.
+ # For more details about why cubic sampling is used, refer to section 3.4 of https://arxiv.org/abs/2302.08453
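+ # With u ~ U(0, 1), the transform (1 - u**3) concentrates samples near the largest
+ # (noisiest) timesteps, where the adapter's conditioning guidance tends to matter most.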
+ timesteps = torch.rand((bsz,), device=latents.device)
+ timesteps = (1 - timesteps**3) * noise_scheduler.config.num_train_timesteps
+ timesteps = timesteps.long().to(noise_scheduler.timesteps.dtype)
+ timesteps = timesteps.clamp(0, noise_scheduler.config.num_train_timesteps - 1)
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Scale the noisy latents for the UNet
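+ # (dividing by sqrt(sigma**2 + 1) mirrors EulerDiscreteScheduler.scale_model_input)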
+ sigmas = get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
+ inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)
+
+ # Adapter conditioning.
+ t2iadapter_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+ down_block_additional_residuals = t2iadapter(t2iadapter_image)
+ down_block_additional_residuals = [
+ sample.to(dtype=weight_dtype) for sample in down_block_additional_residuals
+ ]
+
+ # Predict the noise residual
+ model_pred = unet(
+ inp_noisy_latents,
+ timesteps,
+ encoder_hidden_states=batch["prompt_ids"],
+ added_cond_kwargs=batch["unet_added_conditions"],
+ down_block_additional_residuals=down_block_additional_residuals,
+ return_dict=False,
+ )[0]
+
+ # Denoise the latents
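+ # For epsilon prediction, x0_hat = x_t - sigma * eps_pred recovers the clean latents;
+ # weighting the squared error by 1 / sigma**2 then makes the loss equivalent to an MSE on the noise.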
+ denoised_latents = model_pred * (-sigmas) + noisy_latents
+ weighing = sigmas**-2.0
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = latents # we compute the loss against the denoised latents
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # MSE loss
+ loss = torch.mean(
+ (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ dim=1,
+ )
+ loss = loss.mean()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = t2iadapter.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ vae,
+ unet,
+ t2iadapter,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ t2iadapter = unwrap_model(t2iadapter)
+ t2iadapter.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/test_examples_utils.py b/diffusers/examples/test_examples_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b57a35e1f16ec4bcd401a4392f1bc16e7f0a2408
--- /dev/null
+++ b/diffusers/examples/test_examples_utils.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import subprocess
+import tempfile
+import unittest
+from typing import List
+
+from accelerate.utils import write_basic_config
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+ pass
+
+
+def run_command(command: List[str], return_stdout=False):
+ """
+ Runs `command` with `subprocess.check_output` and optionally returns its `stdout`. Also surfaces, in a readable way,
+ any error that occurred while running `command`.
+ """
+ try:
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ if return_stdout:
+ if hasattr(output, "decode"):
+ output = output.decode("utf-8")
+ return output
+ except subprocess.CalledProcessError as e:
+ raise SubprocessCallException(
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+ ) from e
+
+
+class ExamplesTestsAccelerate(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._tmpdir = tempfile.mkdtemp()
+ cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
+
+ write_basic_config(save_location=cls.configPath)
+ cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ shutil.rmtree(cls._tmpdir)
diff --git a/diffusers/examples/text_to_image/README.md b/diffusers/examples/text_to_image/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0bdf02f804bd837847cb0758d13bcdc9a802f4ef
--- /dev/null
+++ b/diffusers/examples/text_to_image/README.md
@@ -0,0 +1,339 @@
+# Stable Diffusion text-to-image fine-tuning
+
+The `train_text_to_image.py` script shows how to fine-tune the Stable Diffusion model on your own dataset.
+
+___Note___:
+
+___This script is experimental. The script fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
+
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+### Naruto example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+#### Hardware
+With `gradient_checkpointing` and `mixed_precision`, it should be possible to fine-tune the model on a single 24GB GPU. For a higher `batch_size` and faster training, it's better to use GPUs with >30GB of memory.
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-naruto-model"
+```
+
+
+
+To run on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
+If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export TRAIN_DIR="path_to_your_dataset"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-naruto-model"
+```
+
+
+Once training is finished, the model will be saved in the `output_dir` specified in the command. In this example it's `sd-naruto-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+model_path = "path_to_saved_model"
+pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("yoda-naruto.png")
+```
+
+Checkpoints only save the unet, so to run inference from a checkpoint, just load the unet:
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+
+model_path = "path_to_saved_model"
+unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet", torch_dtype=torch.float16)
+
+pipe = StableDiffusionPipeline.from_pretrained("<initial model>", unet=unet, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("yoda-naruto.png")
+```
+
+#### Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-naruto-model"
+```
+
+
+#### Training with Min-SNR weighting
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence
+by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended
+value when using it is 5.0.
+
+You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups:
+
+* Training without the Min-SNR weighting strategy
+* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0)
+* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0)
+
+For our small Naruto dataset, the effects of the Min-SNR weighting strategy might not appear very pronounced, but for larger datasets, we believe the effects will be more pronounced.
+
+Also, note that in this example, we either predict `epsilon` (i.e., the noise) or use `v_prediction`. The formulation of the Min-SNR weighting strategy we have used holds for both of these cases.
+
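+For intuition, here is a minimal sketch (assuming `epsilon` prediction; the helper names are illustrative and this is not the script's exact code) of how Min-SNR loss weights can be derived from the scheduler's `alphas_cumprod`:
+
+```python
+import torch
+
+def compute_snr(noise_scheduler, timesteps):
+    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for a DDPM-style scheduler
+    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
+    alpha_bar = alphas_cumprod[timesteps]
+    return alpha_bar / (1.0 - alpha_bar)
+
+def min_snr_loss_weights(noise_scheduler, timesteps, snr_gamma=5.0):
+    # min(SNR, gamma) / SNR for epsilon prediction;
+    # v-prediction would instead use min(SNR, gamma) / (SNR + 1).
+    snr = compute_snr(noise_scheduler, timesteps)
+    return torch.clamp(snr, max=snr_gamma) / snr
+```
+
+The resulting per-sample weights simply multiply the per-sample MSE before it is averaged into the final loss.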
+
+#### Training with EMA weights
+
+Through the `EMAModel` class, we support a convenient method of tracking an exponential moving average of model parameters. This helps to smooth out noise in model parameter updates and generally improves model performance. If enabled with the `--use_ema` argument, the final model checkpoint that is saved at the end of training will use the EMA weights.
+
+EMA weights require an additional full-precision copy of the model parameters to be stored in memory, but otherwise have very little performance overhead. `--foreach_ema` can be used to further reduce the overhead. If you are short on VRAM and still want to use EMA weights, you can store them in CPU RAM by using the `--offload_ema` argument. This keeps the EMA weights in pinned CPU memory during the training step. Then, on every model parameter update, the EMA weights are transferred back to the GPU, updated there, and sent back to the CPU. Both transfers are set up as non-blocking, so CUDA devices should be able to overlap them with other computations. With sufficient bandwidth between the host and device and a sufficiently long gap between model parameter updates, storing EMA weights in CPU RAM should have no additional performance overhead, as long as no other calls force synchronization.
+
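+As a rough sketch of what `--use_ema` does (simplified, and not the script's exact code), an EMA copy of the UNet is stepped after every optimizer update and copied into the model before the final save:
+
+```python
+from diffusers import UNet2DConditionModel
+from diffusers.training_utils import EMAModel
+
+unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
+ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
+
+# inside the training loop, after optimizer.step():
+ema_unet.step(unet.parameters())
+
+# at the end of training, before saving:
+ema_unet.copy_to(unet.parameters())
+unet.save_pretrained("sd-naruto-model/unet")
+```
+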
+#### Training with DREAM
+
+We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity at the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable *p* and defaults to 1, as in the paper.
+
+
+
+## Training with LoRA
+
+Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow you to control the extent to which the model is adapted toward new training images via a `scale` parameter.
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset
+on consumer GPUs like Tesla T4, Tesla V100.
+
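+To make the rank-decomposition idea concrete, here is a minimal, self-contained sketch of a LoRA-style linear layer (purely illustrative; the training scripts rely on PEFT rather than code like this):
+
+```python
+import torch
+import torch.nn as nn
+
+class LoRALinear(nn.Module):
+    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
+        super().__init__()
+        self.base = base
+        for p in self.base.parameters():
+            p.requires_grad_(False)  # the pretrained weights stay frozen
+        # Only the low-rank factors A and B are trained.
+        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
+        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))
+        self.scale = alpha / rank  # the `scale` knob mentioned above
+
+    def forward(self, x):
+        return self.base(x) + self.scale * (x @ self.lora_A.T @ self.lora_B.T)
+
+layer = LoRALinear(nn.Linear(768, 768), rank=4)
+out = layer(torch.randn(2, 768))  # shape: (2, 768)
+```
+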
+### Training
+
+First, you need to set up your development environment as explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Naruto dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see the generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+```
+
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
+we need to be logged in and add the `--push_to_hub` flag.
+
+```bash
+huggingface-cli login
+```
+
+Now we can start training!
+
+```bash
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=512 --random_flip \
+ --train_batch_size=1 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --output_dir="sd-naruto-model-lora" \
+ --validation_prompt="cute dragon creature" --report_to="wandb"
+```
+
+The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
+
+**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run `train_text_to_image_lora.py` on consumer GPUs like T4 or V100.___**
+
+The final LoRA embedding weights have been uploaded to [sayakpaul/sd-model-finetuned-lora-t4](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4). **___Note: [The final weights](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**
+
+You can check some inference samples that were logged during the course of the fine-tuning process [here](https://wandb.ai/sayakpaul/text2image-fine-tune/runs/q4lc0xsw).
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline` after loading the trained LoRA weights. You
+need to pass the `output_dir` for loading the LoRA weights, which, in this case, is `sd-naruto-model-lora`.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_path = "sayakpaul/sd-model-finetuned-lora-t4"
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
+pipe.unet.load_attn_procs(model_path)
+pipe.to("cuda")
+
+prompt = "A naruto with green eyes and red legs."
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("naruto.png")
+```
+
+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+**___Note: The flax example doesn't yet support features like gradient checkpointing, gradient accumulation, etc., so to use Flax for faster training you will need >30GB cards or a TPU v3.___**
+
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+python train_text_to_image_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --mixed_precision="fp16" \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --output_dir="sd-naruto-model"
+```
+
+To run on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
+If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export TRAIN_DIR="path_to_your_dataset"
+
+python train_text_to_image_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --mixed_precision="fp16" \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --output_dir="sd-naruto-model"
+```
+
+### Training with xFormers
+
+You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.
+
+xFormers training is not available for Flax/JAX.
+
+**Note**:
+
+According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training on some GPUs. If you observe that problem, please install a development version as indicated in that comment.
+
+## Stable Diffusion XL
+
+* We support fine-tuning the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) via the `train_text_to_image_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
+* We also support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
diff --git a/diffusers/examples/text_to_image/README_sdxl.md b/diffusers/examples/text_to_image/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..08d82ac133f416a8075be16ba673265927e0a82a
--- /dev/null
+++ b/diffusers/examples/text_to_image/README_sdxl.md
@@ -0,0 +1,285 @@
+# Stable Diffusion XL text-to-image fine-tuning
+
+The `train_text_to_image_sdxl.py` script shows how to fine-tune Stable Diffusion XL (SDXL) on your own dataset.
+
+🚨 This script is experimental. The script fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset. 🚨
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/text_to_image` folder and run
+```bash
+pip install -r requirements_sdxl.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting the torch compile mode to True can give dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+### Training
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch train_text_to_image_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=$VAE_NAME \
+ --dataset_name=$DATASET_NAME \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --center_crop --random_flip \
+ --proportion_empty_prompts=0.2 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=10000 \
+ --use_8bit_adam \
+ --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --mixed_precision="fp16" \
+ --report_to="wandb" \
+ --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
+ --checkpointing_steps=5000 \
+ --output_dir="sdxl-naruto-model" \
+ --push_to_hub
+```
+
+**Notes**:
+
+* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/naruto-blip-captions`](https://hf.co/datasets/lambdalabs/naruto-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
+* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
+* The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.
+* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, namely `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)); a short snippet after this list shows how to use the same VAE at inference time.
+
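+As a small illustration of the last note, the fixed VAE can also be swapped in at inference time:
+
+```python
+from diffusers import AutoencoderKL, DiffusionPipeline
+import torch
+
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
+)
+pipe.to("cuda")
+```
+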
+### Inference
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+model_path = "you-model-id-goes-here" # <-- change this
+pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+prompt = "A naruto with green eyes and red legs."
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("naruto.png")
+```
+
+### Inference in PyTorch XLA
+```python
+from diffusers import DiffusionPipeline
+import torch
+import torch_xla.core.xla_model as xm
+from time import time
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+pipe = DiffusionPipeline.from_pretrained(model_id)
+
+device = xm.xla_device()
+pipe.to(device)
+
+prompt = "A naruto with green eyes and red legs."
+start = time()
+image = pipe(prompt, num_inference_steps=inference_steps).images[0]
+print(f'Compilation time is {time()-start} sec')
+image.save("naruto.png")
+
+start = time()
+image = pipe(prompt, num_inference_steps=inference_steps).images[0]
+print(f'Inference time is {time()-start} sec after compilation')
+```
+
+Note: There is a warmup step in PyTorch XLA. It takes longer because of
+compilation and optimization. To see the real benefit of PyTorch XLA and the
+speedup, we need to call the pipe again on an input with the same length
+as the original prompt to reuse the optimized graph and get the performance
+boost.
+
+## LoRA training example for Stable Diffusion XL (SDXL)
+
+Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow you to control the extent to which the model is adapted toward new training images via a `scale` parameter.
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset
+on consumer GPUs like Tesla T4, Tesla V100.
+
+### Training
+
+First, you need to set up your development environment as explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Naruto dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).
+
+**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see the generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+```
+
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
+we need to be logged in and add the `--push_to_hub` flag.
+
+```bash
+huggingface-cli login
+```
+
+Now we can start training!
+
+```bash
+accelerate launch train_text_to_image_lora_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=$VAE_NAME \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=1024 --random_flip \
+ --train_batch_size=1 \
+ --num_train_epochs=2 --checkpointing_steps=500 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --mixed_precision="fp16" \
+ --seed=42 \
+ --output_dir="sd-naruto-model-lora-sdxl" \
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
+ --push_to_hub
+```
+
+The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
+
+**Notes**:
+
+* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
+
+
+### Using DeepSpeed
+Using DeepSpeed, one can reduce GPU memory consumption, enabling the training of models on GPUs with smaller memory sizes. DeepSpeed can offload model parameters to the machine's CPU memory, or distribute parameters, gradients, and optimizer states across multiple GPUs. This allows for the training of larger models under the same hardware configuration.
+
+First, use the `accelerate config` command to enable DeepSpeed, or set it up manually in an accelerate config file.
+
+Here is an example of a config file for using DeepSpeed. For more detailed explanations of the configuration, you can refer to this [link](https://huggingface.co/docs/accelerate/usage_guides/deepspeed).
+```yaml
+compute_environment: LOCAL_MACHINE
+debug: true
+deepspeed_config:
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+Save the configuration above as an `accelerate_config.yaml` file, then set the `ACCELERATE_CONFIG_FILE` variable to its path. This way you can use DeepSpeed to train your SDXL model with LoRA. Additionally, you can use DeepSpeed to train other SD models in the same way.
+
+```shell
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+export ACCELERATE_CONFIG_FILE="your accelerate_config.yaml"
+
+accelerate launch --config_file $ACCELERATE_CONFIG_FILE train_text_to_image_lora_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=$VAE_NAME \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --num_train_epochs=2 \
+ --checkpointing_steps=2 \
+ --learning_rate=1e-04 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --mixed_precision="fp16" \
+ --max_train_steps=20 \
+ --validation_epochs=20 \
+ --seed=1234 \
+ --output_dir="sd-naruto-model-lora-sdxl" \
+ --validation_prompt="cute dragon creature"
+```
+
+
+### Finetuning the text encoder and UNet
+
+The script also allows you to finetune the `text_encoder` along with the `unet`.
+
+🚨 Training the text encoder requires additional memory.
+
+Pass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`:
+
+```bash
+accelerate launch train_text_to_image_lora_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=1024 --random_flip \
+ --train_batch_size=1 \
+ --num_train_epochs=2 --checkpointing_steps=500 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --output_dir="sd-naruto-model-lora-sdxl-txt" \
+ --train_text_encoder \
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
+ --push_to_hub
+```
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You
+need to pass the `output_dir` for loading the LoRA weights, which, in this case, is `sd-naruto-model-lora-sdxl`.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+model_path = "takuoko/sd-naruto-model-lora-sdxl"
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
+pipe.to("cuda")
+pipe.load_lora_weights(model_path)
+
+prompt = "A naruto with green eyes and red legs."
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("naruto.png")
+```
diff --git a/diffusers/examples/text_to_image/requirements.txt b/diffusers/examples/text_to_image/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c3ffa42f0edc9862345ab82b69ffb5cdfaeac45e
--- /dev/null
+++ b/diffusers/examples/text_to_image/requirements.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets>=2.19.1
+ftfy
+tensorboard
+Jinja2
+peft==0.7.0
diff --git a/diffusers/examples/text_to_image/requirements_flax.txt b/diffusers/examples/text_to_image/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b6eb64e254625ee8eff2ef126d67adfd5b6994dc
--- /dev/null
+++ b/diffusers/examples/text_to_image/requirements_flax.txt
@@ -0,0 +1,9 @@
+transformers>=4.25.1
+datasets
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/text_to_image/requirements_sdxl.txt b/diffusers/examples/text_to_image/requirements_sdxl.txt
new file mode 100644
index 0000000000000000000000000000000000000000..64cbc9205fd08b960303dfc611214ae6dc48dd7c
--- /dev/null
+++ b/diffusers/examples/text_to_image/requirements_sdxl.txt
@@ -0,0 +1,8 @@
+accelerate>=0.22.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+datasets
+peft==0.7.0
\ No newline at end of file
diff --git a/diffusers/examples/text_to_image/test_text_to_image.py b/diffusers/examples/text_to_image/test_text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..6231a89b1d1d9c16ab34281b223d7365791eab67
--- /dev/null
+++ b/diffusers/examples/text_to_image/test_text_to_image.py
@@ -0,0 +1,365 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, UNet2DConditionModel # noqa: E402
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextToImage(ExamplesTestsAccelerate):
+ def test_text_to_image(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_text_to_image_checkpointing(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 2 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-5"},
+ )
+
+ def test_text_to_image_checkpointing_use_ema(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 2 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-4
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-5"},
+ )
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
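+ # (checkpoint directories are named checkpoint-<global_step> by the training script,
+ # and with a total limit of 2 the oldest directory is evicted before each new save)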
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # resume and we should try to checkpoint at 6, where we'll have to remove
+ # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 8
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8"},
+ )
+
+
+class TextToImageSDXL(ExamplesTestsAccelerate):
+ def test_text_to_image_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
diff --git a/diffusers/examples/text_to_image/test_text_to_image_lora.py b/diffusers/examples/text_to_image/test_text_to_image_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..4604b9f5210ca3e2040cababbd4b7d03b6a18bbb
--- /dev/null
+++ b/diffusers/examples/text_to_image/test_text_to_image_lora.py
@@ -0,0 +1,300 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+from diffusers import DiffusionPipeline # noqa: E402
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextToImageLoRA(ExamplesTestsAccelerate):
+ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # resume and we should try to checkpoint at 6, where we'll have to remove
+ # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 8
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8"},
+ )
+
+
+class TextToImageLoRASDXL(ExamplesTestsAccelerate):
+ def test_text_to_image_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ def test_text_to_image_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when training the text encoder as well, the parameters in the state dict should start
+ # with `"unet"`, `"text_encoder"`, or `"text_encoder_2"` in their names.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
+
+ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --train_text_encoder
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
diff --git a/diffusers/examples/text_to_image/train_text_to_image.py b/diffusers/examples/text_to_image/train_text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..684bf352a6c1abf13a622be88d2444e2e6d7436d
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image.py
@@ -0,0 +1,1153 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images: list = None,
+ repo_folder: str = None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "\n"
+
+ model_description = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}.\n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment is available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_description += wandb_info
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=args.pretrained_model_name_or_path,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--dream_training",
+ action="store_true",
+ help=(
+ "Use the DREAM training method, which makes training more efficient and accurate at the ",
+ "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210",
+ ),
+ )
+ parser.add_argument(
+ "--dream_detail_preservation",
+ type=float,
+ default=1.0,
+ help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
+ parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
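+# Illustrative invocation (a minimal sketch): the model, dataset, and step counts below mirror
+# the tiny fixtures used in the accompanying tests, and the output path is a placeholder.
+#
+#   accelerate launch examples/text_to_image/train_text_to_image.py \
+#     --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe \
+#     --dataset_name hf-internal-testing/dummy_image_text_data \
+#     --resolution 64 --train_batch_size 1 --max_train_steps 4 \
+#     --checkpointing_steps 2 --checkpoints_total_limit 2 \
+#     --output_dir ./sd-finetune-out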
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ Return a list containing a context manager that disables zero.Init when a DeepSpeed plugin is active, or an empty list otherwise.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+ # frozen models from being partitioned during `zero.Init` which gets called during
+ # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+ # across multiple GPUs and only UNet2DConditionModel will get ZeRO sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # Freeze vae and text_encoder and set unet to trainable
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.train()
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ ema_unet = EMAModel(
+ ema_unet.parameters(),
+ model_cls=UNet2DConditionModel,
+ model_config=ema_unet.config,
+ foreach=args.foreach_ema,
+ )
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(
+ os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
+ )
+ ema_unet.load_state_dict(load_model.state_dict())
+ if args.offload_ema:
+ ema_unet.pin_memory()
+ else:
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
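+ # e.g. a base learning rate of 1e-4 with 4 gradient accumulation steps, a per-device batch
+ # size of 16 and 2 processes scales to 1e-4 * 4 * 16 * 2 = 1.28e-2 (illustrative numbers).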
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
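+ # ToTensor maps pixel values to [0, 1] and Normalize([0.5], [0.5]) then maps them to [-1, 1],
+ # the input range expected by the VAE encoder.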
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
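+ # Both counts are multiplied by num_processes because, per the PR referenced above, the
+ # scheduler prepared by accelerate is advanced num_processes times per optimizer step.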
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ if args.offload_ema:
+ ema_unet.pin_memory()
+ else:
+ ema_unet.to(accelerator.device)
+
+ # For mixed precision training we cast the non-trainable weights (vae and text_encoder) to half-precision,
+ # as these weights are only used for inference and keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ args.mixed_precision = accelerator.mixed_precision
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ args.mixed_precision = accelerator.mixed_precision
+
+ # Move text_encoder and vae to the GPU and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Function for unwrapping if model was compiled with `torch.compile`.
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
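+ # The offset term is constant across each latent's spatial dimensions, shifting the mean
+ # of the noise per sample and channel (see the offset-noise blog post linked above).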
+ if args.input_perturbation:
+ new_noise = noise + args.input_perturbation * torch.randn_like(noise)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ if args.input_perturbation:
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+ else:
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.dream_training:
+ noisy_latents, target = compute_dream_and_update_latents(
+ unet,
+ noise_scheduler,
+ timesteps,
+ noise,
+ noisy_latents,
+ target,
+ encoder_hidden_states,
+ args.dream_detail_preservation,
+ )
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
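+ # Worked example of the weighting above: with snr_gamma=5 and epsilon prediction, a
+ # timestep with SNR=20 gets weight min(20, 5)/20 = 0.25 while one with SNR=1 gets
+ # min(1, 5)/1 = 1.0, down-weighting the easy low-noise timesteps.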
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ if args.offload_ema:
+ ema_unet.to(device="cuda", non_blocking=True)
+ ema_unet.step(unet.parameters())
+ if args.offload_ema:
+ ema_unet.to(device="cpu", non_blocking=True)
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
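+ # e.g. with --checkpoints_total_limit=2 and checkpoint-2/checkpoint-4 already on disk,
+ # checkpoint-2 is removed here before checkpoint-6 is written (cf. the tests above)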
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/text_to_image/train_text_to_image_flax.py b/diffusers/examples/text_to_image/train_text_to_image_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a80067d693d64c85ae5befff5852ea7537d9c48
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image_flax.py
@@ -0,0 +1,620 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import torch
+import torch.utils.checkpoint
+import transformers
+from datasets import load_dataset
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxDDPMScheduler,
+ FlaxPNDMScheduler,
+ FlaxStableDiffusionPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--from_pt",
+ action="store_true",
+ default=False,
+ help="Flag to indicate whether to convert models from PyTorch.",
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+dataset_name_mapping = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
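+# The parameters are replicated across devices by `jax_utils.replicate`; take the
+# first replica and copy it back to host memory so it can be serialized.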
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ # Set up logging; we only want one process per machine to log to the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
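+ # Tokenize without padding here; the captions are padded to the tokenizer's
+ # max length per batch in `collate_fn` below.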
+ inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
+ input_ids = inputs.input_ids
+ return input_ids
+
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+
+ return examples
+
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = [example["input_ids"] for example in examples]
+
+ padded_tokens = tokenizer.pad(
+ {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+ )
+ batch = {
+ "pixel_values": pixel_values,
+ "input_ids": padded_tokens.input_ids,
+ }
+ batch = {k: v.numpy() for k, v in batch.items()}
+
+ return batch
+
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True
+ )
+
+ weight_dtype = jnp.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = jnp.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = jnp.bfloat16
+
+ # Load models and create wrapper for stable diffusion
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ from_pt=args.from_pt,
+ revision=args.revision,
+ subfolder="tokenizer",
+ )
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ from_pt=args.from_pt,
+ revision=args.revision,
+ subfolder="text_encoder",
+ dtype=weight_dtype,
+ )
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ from_pt=args.from_pt,
+ revision=args.revision,
+ subfolder="vae",
+ dtype=weight_dtype,
+ )
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ from_pt=args.from_pt,
+ revision=args.revision,
+ subfolder="unet",
+ dtype=weight_dtype,
+ )
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ adamw = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ optimizer = optax.chain(
+ optax.clip_by_global_norm(args.max_grad_norm),
+ adamw,
+ )
+
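+ # Bundle the UNet apply function, its parameters and the optax optimizer chain
+ # into a single train state, which is replicated across devices further below.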
+ state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
+
+ noise_scheduler = FlaxDDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ noise_scheduler_state = noise_scheduler.create_state()
+
+ # Initialize our training
+ rng = jax.random.PRNGKey(args.seed)
+ train_rngs = jax.random.split(rng, jax.local_device_count())
+
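+ # Each local device gets its own RNG stream; `train_step` returns a fresh RNG
+ # so noise and timestep sampling stay independent across steps and devices.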
+ def train_step(state, text_encoder_params, vae_params, batch, train_rng):
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
+
+ def compute_loss(params):
+ # Convert images to latent space
+ vae_outputs = vae.apply(
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ # Sample a random timestep for each image
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(
+ batch["input_ids"],
+ params=text_encoder_params,
+ train=False,
+ )[0]
+
+ # Predict the noise residual and compute loss
+ model_pred = unet.apply(
+ {"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(state.params)
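+ # Average the gradients across all devices in the "batch" pmap axis so every
+ # replica applies the same optimizer update.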
+ grad = jax.lax.pmean(grad, "batch")
+
+ new_state = state.apply_gradients(grads=grad)
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+ return new_state, metrics, new_train_rng
+
+ # Create parallel version of the train step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
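+ # donate_argnums=(0,) lets XLA reuse the buffers of the previous train state
+ # for the new one, lowering peak device memory.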
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+ text_encoder_params = jax_utils.replicate(text_encoder.params)
+ vae_params = jax_utils.replicate(vae_params)
+
+ # Train!
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+
+ epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+ # train
+ for batch in train_dataloader:
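+ # Split the global batch along the leading dimension so each local device
+ # receives `args.train_batch_size` examples for the pmapped train step.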
+ batch = shard(batch)
+ state, train_metric, train_rngs = p_train_step(state, text_encoder_params, vae_params, batch, train_rngs)
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(1)
+
+ global_step += 1
+ if global_step >= args.max_train_steps:
+ break
+
+ train_metric = jax_utils.unreplicate(train_metric)
+
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+ # Create the pipeline using the trained modules and save it.
+ if jax.process_index() == 0:
+ scheduler = FlaxPNDMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+ )
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
+ )
+ pipeline = FlaxStableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+
+ pipeline.save_pretrained(
+ args.output_dir,
+ params={
+ "text_encoder": get_params_to_save(text_encoder_params),
+ "vae": get_params_to_save(vae_params),
+ "unet": get_params_to_save(state.params),
+ "safety_checker": safety_checker.params,
+ },
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/text_to_image/train_text_to_image_lora.py b/diffusers/examples/text_to_image/train_text_to_image_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..379519b4c8122e4892f24f83c667ec8385b73235
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image_lora.py
@@ -0,0 +1,979 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import cast_training_params, compute_snr
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
+{img_str}
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "text-to-image",
+ "diffusers",
+ "diffusers-training",
+ "lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ for _ in range(args.num_validation_images):
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Freeze the unet parameters before adding adapters
+ for param in unet.parameters():
+ param.requires_grad_(False)
+
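+ # Inject rank-`args.rank` LoRA adapters into the attention projections
+ # (to_q, to_k, to_v, to_out.0); only these adapter weights will be trained.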
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Add adapter and make sure the trainable params are in float32.
+ unet.add_adapter(unet_lora_config)
+ if args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(unet, dtype=torch.float32)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
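+ # After `add_adapter`, only the LoRA parameters require gradients, so this
+ # filter yields exactly the weights handed to the optimizer.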
+ lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ lora_layers,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
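+ # Strip the accelerate wrapper and, for torch.compile'd models, the compile
+ # wrapper so the underlying UNet can be saved or reused directly.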
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for a detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
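+ # Offset noise adds a small per-image, per-channel shift to the sampled noise,
+ # which is intended to help the model learn overall brightness changes.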
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
+
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
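+ # Min-SNR weighting: clamp the per-timestep SNR at `snr_gamma`, then divide by
+ # SNR (epsilon prediction) or SNR + 1 (v-prediction) to rebalance the loss.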
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
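+ # Clip gradients only when they are synchronized, i.e. on the real optimizer
+ # step at the end of a gradient-accumulation cycle.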
+ params_to_clip = lora_layers
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+
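+ # Besides the full accelerator state, also export just the LoRA weights in
+ # diffusers format so the checkpoint can be loaded with `load_lora_weights`.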
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(unwrapped_unet)
+ )
+
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=save_path,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
+
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ images = log_validation(pipeline, args, accelerator, epoch)
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ if args.validation_prompt is not None:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py b/diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe098c8638d52ff1ffd4503136f9aee0eaa2ace2
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -0,0 +1,1327 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+ check_min_version,
+ convert_state_dict_to_diffusers,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+if is_torch_npu_available():
+ torch.npu.config.allow_internal_format = False
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ train_text_encoder: bool = False,
+ repo_folder: str = None,
+ vae_path: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below.\n
+{img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers",
+ "diffusers-training",
+ "lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--debug_loss",
+ action="store_true",
+ help="debug loss for each image, if filenames are available in the dataset",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
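+# Illustrative usage only (model, dataset, prompt and output_dir below are placeholders,
+# not defaults of this script; all flags shown are defined in parse_args above):
+#   accelerate launch train_text_to_image_lora_sdxl.py \
+#     --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
+#     --dataset_name=lambdalabs/naruto-blip-captions --caption_column=text \
+#     --resolution=1024 --train_batch_size=1 --gradient_accumulation_steps=4 \
+#     --rank=4 --learning_rate=1e-4 --validation_prompt="a naruto with blue eyes" \
+#     --checkpointing_steps=500 --output_dir=sdxl-naruto-lora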
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
+ )
+
+ # We are only interested in the pooled output of the final (second) text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
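+ # Note on the shapes above (SDXL base, stated as a rough illustration rather than a guarantee):
+ # each prompt_embeds entry is the penultimate hidden state of one text encoder (hidden sizes
+ # 768 and 1280), so the concat along dim=-1 yields 2048-dim token embeddings, while
+ # pooled_prompt_embeds keeps only the pooled output of the second (last) text encoder.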
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ if args.pretrained_vae_model_name_or_path is None:
+ vae.to(accelerator.device, dtype=torch.float32)
+ else:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ unet.enable_npu_flash_attention()
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # now we will add new LoRA weights to the attention layers
+ # Set correct lora layers
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+
+ unet.add_adapter(unet_lora_config)
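+ # Note: with PEFT's LoraConfig the LoRA update is scaled by lora_alpha / r, so using args.rank
+ # for both (as above) keeps that factor at 1.0; only the injected low-rank matrices are trainable,
+ # since the base unet was frozen with requires_grad_(False) earlier.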
+
+ # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
+ if args.train_text_encoder:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+ text_encoder_two.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # there are only two options here: either just the unet attn processor layers,
+ # or the unet and text encoder attn layers
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ _set_state_dict_into_text_encoder(
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_, text_encoder_two_])
+ cast_training_params(models, dtype=torch.float32)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
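+ # For example (illustrative numbers only): with the default 1e-4, --gradient_accumulation_steps=1,
+ # --train_batch_size=16 and 2 processes, the scaled learning rate becomes 1e-4 * 1 * 16 * 2 = 3.2e-3.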
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one, text_encoder_two])
+ cast_training_params(models, dtype=torch.float32)
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ if args.train_text_encoder:
+ params_to_optimize = (
+ params_to_optimize
+ + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ tokens_one = tokenize_prompt(tokenizer_one, captions)
+ tokens_two = tokenize_prompt(tokenizer_two, captions)
+ return tokens_one, tokens_two
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ tokens_one, tokens_two = tokenize_captions(examples)
+ examples["input_ids_one"] = tokens_one
+ examples["input_ids_two"] = tokens_two
+ if args.debug_loss:
+ fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
+ if fnames:
+ examples["filenames"] = fnames
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
+ input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
+ result = {
+ "pixel_values": pixel_values,
+ "input_ids_one": input_ids_one,
+ "input_ids_two": input_ids_two,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ filenames = [example["filenames"] for example in examples if "filenames" in example]
+ if filenames:
+ result["filenames"] = filenames
+ return result
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
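+ # Rough example of the arithmetic above: 1000 dataloader batches per epoch with
+ # --gradient_accumulation_steps=4 give ceil(1000 / 4) = 250 optimizer updates per epoch,
+ # and --max_train_steps=1000 then corresponds to ceil(1000 / 250) = 4 epochs.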
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+ else:
+ pixel_values = batch["pixel_values"]
+
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
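+ # Illustrative note: the extra term adds one constant offset per (sample, channel), scaled by
+ # --noise_offset; per the linked post this shifts the noise mean and can help the model with
+ # very dark or very bright images. Values around 0.05-0.1 are often suggested, though that is
+ # a rule of thumb rather than something this script prescribes.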
+
+ bsz = model_input.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # time ids
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
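+ # Note (SDXL micro-conditioning, as an illustration): each add_time_ids row is the 6-tuple
+ # (orig_h, orig_w, crop_top, crop_left, target_h, target_w) built from the values collected
+ # in preprocess_train, with the target size fixed to (args.resolution, args.resolution).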
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
+ )
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
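+ # Worked illustration of the weighting above (epsilon prediction, --snr_gamma=5.0):
+ # a low-noise timestep with SNR=20 gets weight min(20, 5)/20 = 0.25, while a high-noise
+ # timestep with SNR=0.5 keeps weight min(0.5, 5)/0.5 = 1.0, i.e. easy timesteps are down-weighted.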
+ if args.debug_loss and "filenames" in batch:
+ for fname in batch["filenames"]:
+ accelerator.log({"loss_for_" + fname: loss}, step=global_step)
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=unwrap_model(text_encoder_one),
+ text_encoder_2=unwrap_model(text_encoder_two),
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ images = log_validation(pipeline, args, accelerator, epoch)
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_two = unwrap_model(text_encoder_two)
+
+ text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
+ text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+
+ del unet
+ del text_encoder_one
+ del text_encoder_two
+ del text_encoder_lora_layers
+ del text_encoder_2_lora_layers
+ torch.cuda.empty_cache()
+
+ # Final inference
+ # Make sure vae.dtype is consistent with the unet.dtype
+ if args.mixed_precision == "fp16":
+ vae.to(weight_dtype)
+ # Load previous pipeline
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ if args.validation_prompt and args.num_validation_images > 0:
+ images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ train_text_encoder=args.train_text_encoder,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/text_to_image/train_text_to_image_sdxl.py b/diffusers/examples/text_to_image/train_text_to_image_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca511c857ae62ae62bbec66b801ca3c0960ba08
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image_sdxl.py
@@ -0,0 +1,1345 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image."""
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+from datasets import concatenate_datasets, load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+if is_torch_npu_available():
+ torch.npu.config.allow_internal_format = False
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ validation_prompt: str = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder: str = None,
+ vae_path: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}.\n
+{img_str}
+
+Special VAE used for training: {vae_path}.
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sdxl-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--timestep_bias_strategy",
+ type=str,
+ default="none",
+ choices=["earlier", "later", "range", "none"],
+ help=(
+ "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
+ " Choices: ['earlier', 'later', 'range', 'none']."
+ " The default is 'none', which means no bias is applied, and training proceeds normally."
+ " The value of 'later' will increase the frequency of the model's final training timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_multiplier",
+ type=float,
+ default=1.0,
+ help=(
+ "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
+ " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_begin",
+ type=int,
+ default=0,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
+ " Defaults to zero, which equates to having no specific bias."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_end",
+ type=int,
+ default=1000,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
+ " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_portion",
+ type=float,
+ default=0.25,
+ help=(
+ "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased."
+ " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
+ " whether the biased portions are in the earlier or later timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
+ prompt_embeds_list = []
+ prompt_batch = batch[caption_column]
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ return_dict=False,
+ )
+
+ # Only the pooled output of the final text encoder is kept; it is overwritten on each loop iteration.
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
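+ # SDXL conditions on both text encoders: their penultimate hidden states are concatenated along the feature dimension.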
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
+
+
+def compute_vae_encodings(batch, vae):
+ images = batch.pop("pixel_values")
+ pixel_values = torch.stack(list(images))
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
+
+ with torch.no_grad():
+ model_input = vae.encode(pixel_values).latent_dist.sample()
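+ # Scale the sampled latents by the VAE's scaling factor so they have roughly unit variance, as the UNet expects.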
+ model_input = model_input * vae.config.scaling_factor
+ return {"model_input": model_input.cpu()}
+
+
+def generate_timestep_weights(args, num_timesteps):
+ weights = torch.ones(num_timesteps)
+
+ # Determine the indices to bias
+ num_to_bias = int(args.timestep_bias_portion * num_timesteps)
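+ # e.g. with 1000 timesteps and a bias portion of 0.25, the 250 selected timesteps are up-weighted below.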
+
+ if args.timestep_bias_strategy == "later":
+ bias_indices = slice(-num_to_bias, None)
+ elif args.timestep_bias_strategy == "earlier":
+ bias_indices = slice(0, num_to_bias)
+ elif args.timestep_bias_strategy == "range":
+ # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.
+ range_begin = args.timestep_bias_begin
+ range_end = args.timestep_bias_end
+ if range_begin < 0:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
+ )
+ if range_end > num_timesteps:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
+ )
+ bias_indices = slice(range_begin, range_end)
+ else: # 'none' or any other string
+ return weights
+ if args.timestep_bias_multiplier <= 0:
+ raise ValueError(
+ "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
+ " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
+ " A timestep bias multiplier less than or equal to 0 is not allowed."
+ )
+
+ # Apply the bias
+ weights[bias_indices] *= args.timestep_bias_multiplier
+
+ # Normalize
+ weights /= weights.sum()
+
+ return weights
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ # Check for terminal SNR in combination with SNR Gamma
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # Freeze vae and text encoders.
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ # Set unet as trainable.
+ unet.train()
+
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ unet.enable_npu_flash_attention()
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = unet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory. We will pre-compute the VAE encodings too.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ compute_embeddings_fn = functools.partial(
+ encode_prompt,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ caption_column=args.caption_column,
+ )
+ compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ new_fingerprint_for_vae = Hasher.hash(vae_path)
+ train_dataset_with_embeddings = train_dataset.map(
+ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+ )
+ train_dataset_with_vae = train_dataset.map(
+ compute_vae_encodings_fn,
+ batched=True,
+ batch_size=args.train_batch_size,
+ new_fingerprint=new_fingerprint_for_vae,
+ )
+ precomputed_dataset = concatenate_datasets(
+ [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
+ )
+ precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
+
+ del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
+ del text_encoders, tokenizers, vae
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def collate_fn(examples):
+ model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+ pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
+
+ return {
+ "model_input": model_input,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ precomputed_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
+
+ # Function for unwrapping if torch.compile() was used in accelerate.
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(precomputed_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Sample noise that we'll add to the latents
+ model_input = batch["model_input"].to(accelerator.device)
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+ if args.timestep_bias_strategy == "none":
+ # Sample a random timestep for each image without bias.
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ else:
+ # Sample a random timestep for each image, potentially biased by the timestep weights.
+ # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
+ weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
+ model_input.device
+ )
+ timesteps = torch.multinomial(weights, bsz, replacement=True).long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
+
+ # time ids
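+ # SDXL micro-conditioning: each sample is conditioned on its (original_size, crop_top_left, target_size),
+ # which the UNet receives as extra embeddings via `added_cond_kwargs["time_ids"]`.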
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ elif noise_scheduler.config.prediction_type == "sample":
+ # We set the target to latents here, but the model_pred will return the noise sample prediction.
+ target = model_input
+ # We will have to subtract the noise residual from the prediction to get the target sample.
+ model_pred = model_pred - noise
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = unet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ # create pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with autocast_ctx:
+ images = [
+ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ # Serialize pipeline.
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unet,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline.save_pretrained(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline = pipeline.to(accelerator.device)
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+ with autocast_ctx:
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id=repo_id,
+ images=images,
+ validation_prompt=args.validation_prompt,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/textual_inversion/README.md b/diffusers/examples/textual_inversion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e869bb38d252e6bd39640485c3e6ef494d7e5caa
--- /dev/null
+++ b/diffusers/examples/textual_inversion/README.md
@@ -0,0 +1,153 @@
+## Textual Inversion fine-tuning example
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Running on Colab
+
+Colab for training
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+
+Colab for inference
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run:
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+### Cat toy example
+
+First, let's login so that we can upload the checkpoint to the Hub during training:
+
+```bash
+huggingface-cli login
+```
+
+Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
+
+Let's first download it locally:
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./cat"
+snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes")
+```
+
+This will be our training data.
+Now we can launch the training using:
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+**___Note: Please follow the [README_sdxl.md](./README_sdxl.md) if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model.___**
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export DATA_DIR="./cat"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="<cat-toy>" \
+ --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 \
+ --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --push_to_hub \
+ --output_dir="textual_inversion_cat"
+```
+
+A full training run takes ~1 hour on one V100 GPU.
+
+**Note**: As described in [the official paper](https://arxiv.org/abs/2208.01618)
+only one embedding vector is used for the placeholder token, *e.g.* `"<cat-toy>"`.
+However, one can also add multiple embedding vectors for the placeholder token
+to increase the number of fine-tuneable parameters. This can help the model to learn
+more complex details. To use multiple embedding vectors, you should define `--num_vectors`
+to a number larger than one, *e.g.*:
+```bash
+--num_vectors 5
+```
+
+The saved textual inversion vectors will then be larger in size compared to the default case.
+
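+For a quick sanity check, you can inspect the saved `learned_embeds.safetensors` file and look at the embedding shape. This is a minimal sketch; the output path and the `--num_vectors 5` value are assumptions taken from the commands above:
+
+```py
+from safetensors.torch import load_file
+
+# Path assumed from the training command above; point this at your own `--output_dir`.
+state_dict = load_file("textual_inversion_cat/learned_embeds.safetensors")
+for token, embedding in state_dict.items():
+    # With `--num_vectors 5`, the first dimension should be 5 instead of 1.
+    print(token, tuple(embedding.shape))
+```
+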
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+repo_id_embeds = "path-to-your-learned-embeds"
+pipe.load_textual_inversion(repo_id_embeds)
+
+prompt = "A backpack"
+
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("cat-backpack.png")
+```
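+
+The standalone `learned_embeds.safetensors` file can also be loaded into a fresh base pipeline with `load_textual_inversion`. Below is a sketch, assuming the default `--output_dir` from the training command above and the `<cat-toy>` placeholder token:
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# Load only the embedding file produced during training (directory and file name assumed).
+pipe.load_textual_inversion("textual_inversion_cat", weight_name="learned_embeds.safetensors")
+
+image = pipe("A <cat-toy> backpack", num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack-from-embeds.png")
+```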
+
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="path-to-dir-containing-images"
+
+python textual_inversion_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="<cat-toy>" \
+ --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 \
+ --scale_lr \
+ --output_dir="textual_inversion_cat"
+```
+It should be at least 70% faster than the PyTorch script with the same configuration.
+
+### Training with xFormers
+You can enable memory-efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
diff --git a/diffusers/examples/textual_inversion/README_sdxl.md b/diffusers/examples/textual_inversion/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..c971ee2a0f39d3fe4baef1b3a834b65dc569d2fb
--- /dev/null
+++ b/diffusers/examples/textual_inversion/README_sdxl.md
@@ -0,0 +1,47 @@
+## Textual Inversion fine-tuning example for SDXL
+
+```sh
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export DATA_DIR="./cat"
+
+accelerate launch textual_inversion_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="<cat-toy>" \
+ --initializer_token="toy" \
+ --mixed_precision="bf16" \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=500 \
+ --learning_rate=5.0e-04 \
+ --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --save_as_full_pipeline \
+ --output_dir="./textual_inversion_cat_sdxl"
+```
+
+Training of both text encoders is supported.
+
+### Inference Example
+
+Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionXLPipeline`.
+Make sure to include the `placeholder_token` in your prompt.
+
+```python
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+model_id = "./textual_inversion_cat_sdxl"
+pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A backpack"
+
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack.png")
+
+image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack-prompt_2.png")
+```
diff --git a/diffusers/examples/textual_inversion/requirements.txt b/diffusers/examples/textual_inversion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7a612982f4abbaa64f83db52e411a1235a372259
--- /dev/null
+++ b/diffusers/examples/textual_inversion/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/textual_inversion/requirements_flax.txt b/diffusers/examples/textual_inversion/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8f85ad523a3b46b65abf0138c05ecdd656e6845c
--- /dev/null
+++ b/diffusers/examples/textual_inversion/requirements_flax.txt
@@ -0,0 +1,8 @@
+transformers>=4.25.1
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/textual_inversion/test_textual_inversion.py b/diffusers/examples/textual_inversion/test_textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa0b2c2bcf09998c7001a1e72b77600e2f946326
--- /dev/null
+++ b/diffusers/examples/textual_inversion/test_textual_inversion.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextualInversion(ExamplesTestsAccelerate):
+ def test_textual_inversion(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <new-token>
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
+
+ def test_textual_inversion_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <new-token>
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
+
+ def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <new-token>
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2"},
+ )
+
+ resume_run_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <new-token>
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
diff --git a/diffusers/examples/textual_inversion/test_textual_inversion_sdxl.py b/diffusers/examples/textual_inversion/test_textual_inversion_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a861708a70f9c3d35dce5cbd4d2e8a206d9d11a5
--- /dev/null
+++ b/diffusers/examples/textual_inversion/test_textual_inversion_sdxl.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class TextualInversionSdxl(ExamplesTestsAccelerate):
+ def test_textual_inversion_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <new-token>
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
+
+ def test_textual_inversion_sdxl_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <new-token>
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
+
+ def test_textual_inversion_sdxl_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <new-token>
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2"},
+ )
+
+ resume_run_args = f"""
+ examples/textual_inversion/textual_inversion_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+                --placeholder_token <cat-toy>
+ --initializer_token a
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
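The two checkpointing tests above exercise the rotation policy implemented by the training scripts: before each `accelerator.save_state`, older `checkpoint-<step>` directories are pruned so that, counting the new save, at most `--checkpoints_total_limit` remain on disk. A minimal sketch of that policy as a pure function (the `surviving_checkpoints` helper is hypothetical, for illustration only):

def surviving_checkpoints(saved_steps, total_limit):
    """Checkpoint steps left on disk after saving each step in order under a rotation limit."""
    kept = []
    for step in sorted(saved_steps):
        kept.append(step)
        if len(kept) > total_limit:
            kept = kept[-total_limit:]  # prune the oldest checkpoints first
    return kept

# Mirrors the assertion above: three saves with a limit of 2 leave checkpoint-2 and checkpoint-3.
assert surviving_checkpoints([1, 2, 3], total_limit=2) == [2, 3]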
diff --git a/diffusers/examples/textual_inversion/textual_inversion.py b/diffusers/examples/textual_inversion/textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b710531836b1b4bc1ea5a2c49b13b797fc19d85
--- /dev/null
+++ b/diffusers/examples/textual_inversion/textual_inversion.py
@@ -0,0 +1,1022 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import PIL
+import safetensors
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(repo_id: str, images: list = None, base_model: str = None, repo_folder: str = None):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+            img_str += f"![img_{i}](./image_{i}.png)\n"
+ model_description = f"""
+# Textual inversion text2image fine-tuning - {repo_id}
+These are textual inversion adaptation weights for {base_model}. You can find some example images below.\n
+{img_str}
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "text-to-image",
+ "diffusers",
+ "textual_inversion",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=unet,
+ vae=vae,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
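+        # Autocast is skipped on MPS backends; inference there runs directly in weight_dtype.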
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+ return images
+
+
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
+ logger.info("Saving embeddings")
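+    # The placeholder tokens were added to the tokenizer consecutively, so their ids
+    # form a contiguous slice of the embedding matrix.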
+ learned_embeds = (
+ accelerator.unwrap_model(text_encoder)
+ .get_input_embeddings()
+ .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
+ )
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+
+ if safe_serialization:
+ safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
+ else:
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+        help="Save learned_embeds.bin every X update steps.",
+ )
+ parser.add_argument(
+ "--save_as_full_pipeline",
+ action="store_true",
+ help="Save the complete stable diffusion pipeline.",
+ )
+ parser.add_argument(
+ "--num_vectors",
+ type=int,
+ default=1,
+ help="How many textual inversion vectors shall be used to learn the concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
+            " and an Nvidia Ampere GPU or Intel Gen 4 Xeon (and later)."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` a total of `args.num_validation_images` times"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=None,
+ help=(
+ "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
+            " `args.validation_prompt` a total of `args.num_validation_images` times"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--no_safe_serialization",
+ action="store_true",
+ help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # Add the placeholder token in tokenizer
+ placeholder_tokens = [args.placeholder_token]
+
+ if args.num_vectors < 1:
+ raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")
+
+ # add dummy tokens for multi-vector
+ additional_tokens = []
+ for i in range(1, args.num_vectors):
+ additional_tokens.append(f"{args.placeholder_token}_{i}")
+ placeholder_tokens += additional_tokens
+
+ num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
+ if num_added_tokens != args.num_vectors:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ with torch.no_grad():
+ for token_id in placeholder_token_ids:
+ token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+
+ # Freeze vae and unet
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ # Freeze all parameters except for the token embeddings in text encoder
+ text_encoder.text_model.encoder.requires_grad_(False)
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ # Keep unet in train mode if we are using gradient checkpointing to save memory.
+        # Dropout in the unet is 0, so it does not matter whether it is in eval or train mode.
+ unet.train()
+ text_encoder.gradient_checkpointing_enable()
+ unet.enable_gradient_checkpointing()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
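+        # For multi-vector runs, all placeholder tokens are joined so that every
+        # learned vector appears in the prompt templates.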
+ placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+ if args.validation_epochs is not None:
+ warnings.warn(
+ f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
+            " `validation_epochs` is deprecated in favor of `validation_steps`."
+            f" Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}.",
+ FutureWarning,
+ stacklevel=2,
+ )
+ args.validation_steps = args.validation_epochs * len(train_dataset)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ )
+
+ text_encoder.train()
+ # Prepare everything with our `accelerator`.
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+
+    # For mixed precision training we cast all non-trainable weights (vae and unet) to half-precision,
+    # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and unet to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ # keep original embeddings as reference
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
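+                # Scale latents by the VAE scaling factor so their magnitude matches what the UNet was trained on.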
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+ index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+
+ with torch.no_grad():
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ images = []
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ weight_name = (
+ f"learned_embeds-steps-{global_step}.bin"
+ if args.no_safe_serialization
+ else f"learned_embeds-steps-{global_step}.safetensors"
+ )
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder,
+ placeholder_token_ids,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=not args.no_safe_serialization,
+ )
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.push_to_hub and not args.save_as_full_pipeline:
+ logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = args.save_as_full_pipeline
+ if save_full_model:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder,
+ placeholder_token_ids,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=not args.no_safe_serialization,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
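For reference, a minimal inference sketch for the embeddings written by `save_progress` above. It assumes a diffusers version that provides `StableDiffusionPipeline.load_textual_inversion`; the base model id, output path, and `<my-token>` placeholder are illustrative and must match the values used for training:

import torch

from diffusers import StableDiffusionPipeline

# Load the same base model the embeddings were trained against.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# learned_embeds.safetensors maps the placeholder token to its trained embedding row(s).
pipe.load_textual_inversion("text-inversion-model/learned_embeds.safetensors")

image = pipe("a photo of <my-token> on a wooden table", num_inference_steps=30).images[0]
image.save("textual_inversion_sample.png")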
diff --git a/diffusers/examples/textual_inversion/textual_inversion_flax.py b/diffusers/examples/textual_inversion/textual_inversion_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee7b1580d14520f9a8e44c81844c803d6932a414
--- /dev/null
+++ b/diffusers/examples/textual_inversion/textual_inversion_flax.py
@@ -0,0 +1,681 @@
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import PIL
+import torch
+import torch.utils.checkpoint
+import transformers
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxDDPMScheduler,
+ FlaxPNDMScheduler,
+ FlaxStableDiffusionPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+        help="Save learned_embeds.bin every X update steps.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=True,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument(
+ "--use_auth_token",
+ action="store_true",
+ help=(
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+ " private models)."
+ ),
+ )
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
+ if model.config.vocab_size == new_num_tokens or new_num_tokens is None:
+ return
+ model.config.vocab_size = new_num_tokens
+
+ params = model.params
+ old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+ old_num_tokens, emb_dim = old_embeddings.shape
+
+ initializer = jax.nn.initializers.normal()
+
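+    # Grow the embedding matrix: keep the existing rows, randomly initialize the new ones,
+    # then copy the initializer token's embedding into the placeholder slot.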
+ new_embeddings = initializer(rng, (new_num_tokens, emb_dim))
+ new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)
+ new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])
+ params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings
+
+ model.params = params
+ return model
+
+
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+    # Set up logging; we only want one process per machine to log things on the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # Create sampling rng
+ rng = jax.random.PRNGKey(args.seed)
+ rng, _ = jax.random.split(rng)
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder = resize_token_embeddings(
+ text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng
+ )
+ original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+
+ batch = {"pixel_values": pixel_values, "input_ids": input_ids}
+ batch = {k: v.numpy() for k, v in batch.items()}
+
+ return batch
+
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
+ )
+
+ # Optimization
+ if args.scale_lr:
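+        # Scale the base learning rate linearly with the total per-step batch size across local devices.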
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ optimizer = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ def create_mask(params, label_fn):
+ def _map(params, mask, label_fn):
+ for k in params:
+ if label_fn(k):
+ mask[k] = "token_embedding"
+ else:
+ if isinstance(params[k], dict):
+ mask[k] = {}
+ _map(params[k], mask[k], label_fn)
+ else:
+ mask[k] = "zero"
+
+ mask = {}
+ _map(params, mask, label_fn)
+ return mask
+
+ def zero_grads():
+ # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
+ def init_fn(_):
+ return ()
+
+ def update_fn(updates, state, params=None):
+ return jax.tree_util.tree_map(jnp.zeros_like, updates), ()
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+ # Zero out gradients of layers other than the token embedding layer
+ tx = optax.multi_transform(
+ {"token_embedding": optimizer, "zero": zero_grads()},
+ create_mask(text_encoder.params, lambda s: s == "token_embedding"),
+ )
+
+ state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)
+
+ noise_scheduler = FlaxDDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ noise_scheduler_state = noise_scheduler.create_state()
+
+ # Initialize our training
+ train_rngs = jax.random.split(rng, jax.local_device_count())
+
+ # Define gradient train step fn
+ def train_step(state, vae_params, unet_params, batch, train_rng):
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
+
+ def compute_loss(params):
+ vae_outputs = vae.apply(
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+ encoder_hidden_states = state.apply_fn(
+ batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
+ )[0]
+ # Predict the noise residual and compute loss
+ model_pred = unet.apply(
+ {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(state.params)
+ grad = jax.lax.pmean(grad, "batch")
+ new_state = state.apply_gradients(grads=grad)
+
+ # Keep the token embeddings fixed except the newly added embeddings for the concept,
+ # as we only want to optimize the concept embeddings
+ token_embeds = original_token_embeds.at[placeholder_token_id].set(
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id]
+ )
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+ return new_state, metrics, new_train_rng
+
+ # Create parallel version of the train and eval step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+ vae_params = jax_utils.replicate(vae_params)
+ unet_params = jax_utils.replicate(unet_params)
+
+ # Train!
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+
+ epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+ # train
+ for batch in train_dataloader:
+ batch = shard(batch)
+ state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(1)
+ global_step += 1
+
+ if global_step >= args.max_train_steps:
+ break
+ if global_step % args.save_steps == 0:
+ learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][
+ "embedding"
+ ][placeholder_token_id]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds}
+ jnp.save(
+ os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict
+ )
+
+ train_metric = jax_utils.unreplicate(train_metric)
+
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+    # Create the pipeline using the trained modules and save it.
+ if jax.process_index() == 0:
+ scheduler = FlaxPNDMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+ )
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
+ )
+ pipeline = FlaxStableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+
+ pipeline.save_pretrained(
+ args.output_dir,
+ params={
+ "text_encoder": get_params_to_save(state.params),
+ "vae": get_params_to_save(vae_params),
+ "unet": get_params_to_save(unet_params),
+ "safety_checker": safety_checker.params,
+ },
+ )
+
+ # Also save the newly trained embeddings
+ learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+ placeholder_token_id
+ ]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds}
+ jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/textual_inversion/textual_inversion_sdxl.py b/diffusers/examples/textual_inversion/textual_inversion_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4629f0f43d613efef87f9f19d2b997bc2de5c74
--- /dev/null
+++ b/diffusers/examples/textual_inversion/textual_inversion_sdxl.py
@@ -0,0 +1,1122 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import numpy as np
+import PIL
+import safetensors
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = None, repo_folder=None):
+    img_str = ""
+    for i, image in enumerate(images):
+        image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# Textual inversion text2image fine-tuning - {repo_id}
+These are textual inversion adaptation weights for {base_model}. You can find some example images below. \n
+{img_str}
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers",
+ "diffusers-training",
+ "textual_inversion",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ text_encoder_1,
+ text_encoder_2,
+ tokenizer_1,
+ tokenizer_2,
+ unet,
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder_1),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_2),
+ tokenizer=tokenizer_1,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ vae=vae,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(tracker_key, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ tracker_key: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+ return images
+
+
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
+ logger.info("Saving embeddings")
+ learned_embeds = (
+ accelerator.unwrap_model(text_encoder)
+ .get_input_embeddings()
+ .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
+ )
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+
+ if safe_serialization:
+ safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
+ else:
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+        help="Save learned_embeds.bin every X update steps.",
+ )
+ parser.add_argument(
+ "--save_as_full_pipeline",
+ action="store_true",
+ help="Save the complete stable diffusion pipeline.",
+ )
+ parser.add_argument(
+ "--num_vectors",
+ type=int,
+ default=1,
+ help="How many textual inversion vectors shall be used to learn the concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
+            " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer_1,
+ tokenizer_2,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer_1 = tokenizer_1
+ self.tokenizer_2 = tokenizer_2
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+ self.crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["original_size"] = (image.height, image.width)
+
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ if self.center_crop:
+ y1 = max(0, int(round((image.height - self.size) / 2.0)))
+ x1 = max(0, int(round((image.width - self.size) / 2.0)))
+ image = self.crop(image)
+ else:
+ y1, x1, h, w = self.crop.get_params(image, (self.size, self.size))
+ image = transforms.functional.crop(image, y1, x1, h, w)
+
+ example["crop_top_left"] = (y1, x1)
+
+ example["input_ids_1"] = self.tokenizer_1(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer_1.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ example["input_ids_2"] = self.tokenizer_2(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer_2.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ image = Image.fromarray(img)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load tokenizer
+ tokenizer_1 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_1 = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # Add the placeholder token in tokenizer_1
+ placeholder_tokens = [args.placeholder_token]
+
+ if args.num_vectors < 1:
+        raise ValueError(f"--num_vectors has to be greater than or equal to 1, but is {args.num_vectors}")
+
+ # add dummy tokens for multi-vector
+ additional_tokens = []
+ for i in range(1, args.num_vectors):
+ additional_tokens.append(f"{args.placeholder_token}_{i}")
+ placeholder_tokens += additional_tokens
+
+ num_added_tokens = tokenizer_1.add_tokens(placeholder_tokens)
+ if num_added_tokens != args.num_vectors:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+ num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
+ if num_added_tokens != args.num_vectors:
+ raise ValueError(
+ f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
+ token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)
+
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1 or len(token_ids_2) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
+ initializer_token_id_2 = token_ids_2[0]
+ placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder_1.resize_token_embeddings(len(tokenizer_1))
+ text_encoder_2.resize_token_embeddings(len(tokenizer_2))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder_1.get_input_embeddings().weight.data
+ token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
+ with torch.no_grad():
+ for token_id in placeholder_token_ids:
+ token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+ for token_id in placeholder_token_ids_2:
+ token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()
+
+ # Freeze vae and unet
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # Freeze all parameters except for the token embeddings in text encoder
+ text_encoder_1.text_model.encoder.requires_grad_(False)
+ text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
+ text_encoder_2.text_model.encoder.requires_grad_(False)
+ text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ text_encoder_1.gradient_checkpointing_enable()
+ text_encoder_2.gradient_checkpointing_enable()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ # only optimize the embeddings
+ [
+ text_encoder_1.text_model.embeddings.token_embedding.weight,
+ text_encoder_2.text_model.embeddings.token_embedding.weight,
+ ],
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ placeholder_token = " ".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))
+ # Dataset and DataLoaders creation:
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer_1=tokenizer_1,
+ tokenizer_2=tokenizer_2,
+ size=args.resolution,
+ placeholder_token=placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ )
+
+ text_encoder_1.train()
+ text_encoder_2.train()
+ # Prepare everything with our `accelerator`.
+ text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
+ )
+
+    # For mixed precision training we cast the non-trainable weights (vae and unet) to half-precision,
+    # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and unet and text_encoder_2 to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ # keep original embeddings as reference
+ orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
+ orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ text_encoder_1.train()
+ text_encoder_2.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate([text_encoder_1, text_encoder_2]):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states_1 = (
+ text_encoder_1(batch["input_ids_1"], output_hidden_states=True)
+ .hidden_states[-2]
+ .to(dtype=weight_dtype)
+ )
+ encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
+ encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
+ original_size = [
+ (batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
+                    for i in range(bsz)
+ ]
+ crop_top_left = [
+ (batch["crop_top_left"][0][i].item(), batch["crop_top_left"][1][i].item())
+                    for i in range(bsz)
+ ]
+ target_size = (args.resolution, args.resolution)
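+                # SDXL micro-conditioning: each sample is described by six integers
+                # (original_size, crop_top_left, target_size) that the UNet embeds and adds to its time embedding.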
+ add_time_ids = torch.cat(
+ [
+ torch.tensor(original_size[i] + crop_top_left[i] + target_size)
+                        for i in range(bsz)
+ ]
+ ).to(accelerator.device, dtype=weight_dtype)
+ added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids}
+ encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1)
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
+ index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+ index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
+ index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
+
+ with torch.no_grad():
+ accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
+ index_no_updates_2
+ ] = orig_embeds_params_2[index_no_updates_2]
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ images = []
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ weight_name = f"learned_embeds-steps-{global_step}.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder_1,
+ placeholder_token_ids,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=True,
+ )
+ weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder_2,
+ placeholder_token_ids_2,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=True,
+ )
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ text_encoder_1,
+ text_encoder_2,
+ tokenizer_1,
+ tokenizer_2,
+ unet,
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ epoch,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.validation_prompt:
+ images = log_validation(
+ text_encoder_1,
+ text_encoder_2,
+ tokenizer_1,
+ tokenizer_2,
+ unet,
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub and not args.save_as_full_pipeline:
+ logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = args.save_as_full_pipeline
+ if save_full_model:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder_1),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_2),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer_1,
+ tokenizer_2=tokenizer_2,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ weight_name = "learned_embeds.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder_1,
+ placeholder_token_ids,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=True,
+ )
+ weight_name = "learned_embeds_2.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder_2,
+ placeholder_token_ids_2,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/unconditional_image_generation/README.md b/diffusers/examples/unconditional_image_generation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2990b3abf3f5b09af67ed2d0da96ce3bc297e1f6
--- /dev/null
+++ b/diffusers/examples/unconditional_image_generation/README.md
@@ -0,0 +1,163 @@
+## Training an unconditional diffusion model
+
+Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, since we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run:
+```bash
+pip install -r requirements.txt
+```
+
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+### Unconditional Flowers
+
+The command to train a DDPM UNet model on the Oxford Flowers dataset:
+
+```bash
+accelerate launch train_unconditional.py \
+ --dataset_name="huggan/flowers-102-categories" \
+ --resolution=64 --center_crop --random_flip \
+ --output_dir="ddpm-ema-flowers-64" \
+ --train_batch_size=16 \
+ --num_epochs=100 \
+ --gradient_accumulation_steps=1 \
+ --use_ema \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision=no \
+ --push_to_hub
+```
+An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64
+
+A full training run takes 2 hours on 4xV100 GPUs.
+
+
+
+
+### Unconditional Pokemon
+
+The command to train a DDPM UNet model on the Pokemon dataset:
+
+```bash
+accelerate launch train_unconditional.py \
+ --dataset_name="huggan/pokemon" \
+ --resolution=64 --center_crop --random_flip \
+ --output_dir="ddpm-ema-pokemon-64" \
+ --train_batch_size=16 \
+ --num_epochs=100 \
+ --gradient_accumulation_steps=1 \
+ --use_ema \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision=no \
+ --push_to_hub
+```
+An example trained model: https://huggingface.co/anton-l/ddpm-ema-pokemon-64
+
+A full training run takes 2 hours on 4xV100 GPUs.
+
+
+
+### Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
+ --dataset_name="huggan/pokemon" \
+ --resolution=64 --center_crop --random_flip \
+ --output_dir="ddpm-ema-pokemon-64" \
+ --train_batch_size=16 \
+ --num_epochs=100 \
+ --gradient_accumulation_steps=1 \
+ --use_ema \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision="fp16" \
+ --logger="wandb"
+```
+
+To be able to use Weights and Biases (`wandb`) as a logger you need to install the library: `pip install wandb`.
+
+### Using your own data
+
+To use your own dataset, there are two ways:
+- you can either provide your own folder as `--train_data_dir`,
+- or you can upload your dataset to the hub (possibly as a private repo, if you prefer) and simply pass the `--dataset_name` argument.
+
+Below, we explain both in more detail.
+
+#### Provide the dataset as a folder
+
+If you provide your own folders with images, the script expects the following directory structure:
+
+```bash
+data_dir/xxx.png
+data_dir/xxy.png
+data_dir/[...]/xxz.png
+```
+
+In other words, the script will take care of gathering all images inside the folder. You can then run the script like this:
+
+```bash
+accelerate launch train_unconditional.py \
+    --train_data_dir <path-to-train-directory> \
+    <other-arguments>
+```
+
+Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
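+
+For instance, here is a minimal sketch of what that looks like if you load the folder yourself (the folder path is just an illustration):
+
+```python
+from datasets import load_dataset
+
+# "imagefolder" recursively collects the images in the folder and exposes them as a Dataset
+dataset = load_dataset("imagefolder", data_dir="path_to_your_folder", split="train")
+print(dataset[0]["image"])  # each example carries an `image` column holding a PIL image
+```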
+
+#### Upload your data to the hub, as a (possibly private) repo
+
+It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:
+
+```python
+from datasets import load_dataset
+
+# example 1: local folder
+dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
+
+# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)
+dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
+
+# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
+dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")
+
+# example 4: providing several splits
+dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
+```
+
+`ImageFolder` will create an `image` column containing the PIL-encoded images.
+
+Next, push it to the hub!
+
+```python
+# assuming you have run the huggingface-cli login command in a terminal
+dataset.push_to_hub("name_of_your_dataset")
+
+# if you want to push to a private repo, simply pass private=True:
+dataset.push_to_hub("name_of_your_dataset", private=True)
+```
+
+and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
+
+More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
diff --git a/diffusers/examples/unconditional_image_generation/requirements.txt b/diffusers/examples/unconditional_image_generation/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f366720afd11e41945a3f29472be2048dbf98404
--- /dev/null
+++ b/diffusers/examples/unconditional_image_generation/requirements.txt
@@ -0,0 +1,3 @@
+accelerate>=0.16.0
+torchvision
+datasets
diff --git a/diffusers/examples/unconditional_image_generation/test_unconditional.py b/diffusers/examples/unconditional_image_generation/test_unconditional.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef71da0a114c5b1cc36ebb1bfb3eacdeaa2598f0
--- /dev/null
+++ b/diffusers/examples/unconditional_image_generation/test_unconditional.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class Unconditional(ExamplesTestsAccelerate):
+ def test_train_unconditional(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 2
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ """.split()
+
+ run_command(self._launch_args + test_args, return_stdout=True)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_unconditional_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 1
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
+ )
+
+ resume_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 2
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 1
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --resume_from_checkpoint=checkpoint-6
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-10", "checkpoint-12"},
+ )
diff --git a/diffusers/examples/unconditional_image_generation/train_unconditional.py b/diffusers/examples/unconditional_image_generation/train_unconditional.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f5e1de240cb5e218d85828d73dd92d88d39d778
--- /dev/null
+++ b/diffusers/examples/unconditional_image_generation/train_unconditional.py
@@ -0,0 +1,704 @@
+import argparse
+import inspect
+import logging
+import math
+import os
+import shutil
+from datetime import timedelta
+from pathlib import Path
+
+import accelerate
+import datasets
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator, InitProcessGroupKwargs
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ if not isinstance(arr, torch.Tensor):
+ arr = torch.from_numpy(arr)
+ res = arr[timesteps].float().to(timesteps.device)
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that HF Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--model_config_name_or_path",
+ type=str,
+ default=None,
+ help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+            "A folder containing the training data. Folder contents must follow the structure described in"
+            " https://huggingface.co/docs/datasets/image_dataset#imagefolder (no captions are needed for"
+            " unconditional training). Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="ddpm-model-64",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--overwrite_output_dir", action="store_true")
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=64,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ default=False,
+ action="store_true",
+        help="Whether to randomly flip images horizontally.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+ " process."
+ ),
+ )
+ parser.add_argument("--num_epochs", type=int, default=100)
+ parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+ parser.add_argument(
+ "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="cosine",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+ parser.add_argument(
+ "--use_ema",
+ action="store_true",
+ help="Whether to use Exponential Moving Average for the final model weights.",
+ )
+ parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+ parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+ parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+ )
+ parser.add_argument(
+ "--logger",
+ type=str,
+ default="tensorboard",
+ choices=["tensorboard", "wandb"],
+ help=(
+ "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
+ " for experiment tracking and logging of model metrics and model checkpoints"
+ ),
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
+ " and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default="epsilon",
+ choices=["epsilon", "sample"],
+ help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+ )
+ parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+ return args
+
+
+def main(args):
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200)) # a big number for high resolution or big dataset
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.logger,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.logger == "tensorboard":
+ if not is_tensorboard_available():
+ raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
+
+ elif args.logger == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
+ ema_model.load_state_dict(load_model.state_dict())
+ ema_model.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Initialize the model
+ if args.model_config_name_or_path is None:
+ model = UNet2DModel(
+ sample_size=args.resolution,
+ in_channels=3,
+ out_channels=3,
+ layers_per_block=2,
+ block_out_channels=(128, 128, 256, 256, 512, 512),
+ down_block_types=(
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "AttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types=(
+ "UpBlock2D",
+ "AttnUpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ ),
+ )
+ else:
+ config = UNet2DModel.load_config(args.model_config_name_or_path)
+ model = UNet2DModel.from_config(config)
+
+ # Create EMA for the model.
+ if args.use_ema:
+ ema_model = EMAModel(
+ model.parameters(),
+ decay=args.ema_max_decay,
+ use_ema_warmup=True,
+ inv_gamma=args.ema_inv_gamma,
+ power=args.ema_power,
+ model_cls=UNet2DModel,
+ model_config=model.config,
+ )
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ args.mixed_precision = accelerator.mixed_precision
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ args.mixed_precision = accelerator.mixed_precision
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ model.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Initialize the scheduler
+ accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+ if accepts_prediction_type:
+ noise_scheduler = DDPMScheduler(
+ num_train_timesteps=args.ddpm_num_steps,
+ beta_schedule=args.ddpm_beta_schedule,
+ prediction_type=args.prediction_type,
+ )
+ else:
+ noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ split="train",
+ )
+ else:
+ dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets and DataLoaders creation.
+ augmentations = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def transform_images(examples):
+ images = [augmentations(image.convert("RGB")) for image in examples["image"]]
+ return {"input": images}
+
+ logger.info(f"Dataset size: {len(dataset)}")
+
+ dataset.set_transform(transform_images)
+ train_dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # Initialize the learning rate scheduler
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=(len(train_dataloader) * args.num_epochs),
+ )
+
+ # Prepare everything with our `accelerator`.
+ model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ model, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_model.to(accelerator.device)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ run = os.path.split(__file__)[-1].split(".")[0]
+ accelerator.init_trackers(run)
+
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ max_train_steps = args.num_epochs * num_update_steps_per_epoch
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(dataset)}")
+ logger.info(f" Num Epochs = {args.num_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_train_steps}")
+
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Train!
+ for epoch in range(first_epoch, args.num_epochs):
+ model.train()
+ progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
+ progress_bar.set_description(f"Epoch {epoch}")
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ clean_images = batch["input"].to(weight_dtype)
+ # Sample noise that we'll add to the images
+ noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)
+ bsz = clean_images.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
+ ).long()
+
+ # Add noise to the clean images according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
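+ # i.e. noisy_images = sqrt(alphas_cumprod[t]) * clean_images + sqrt(1 - alphas_cumprod[t]) * noise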
+ noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
+
+ with accelerator.accumulate(model):
+ # Predict the noise residual
+ model_output = model(noisy_images, timesteps).sample
+
+ if args.prediction_type == "epsilon":
+ loss = F.mse_loss(model_output.float(), noise.float()) # this could have different weights!
+ elif args.prediction_type == "sample":
+ alpha_t = _extract_into_tensor(
+ noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
+ )
+ snr_weights = alpha_t / (1 - alpha_t)
+ # use SNR weighting from distillation paper
+ loss = snr_weights * F.mse_loss(model_output.float(), clean_images.float(), reduction="none")
+ loss = loss.mean()
+ else:
+ raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
+
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_model.step(model.parameters())
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+ if args.use_ema:
+ logs["ema_decay"] = ema_model.cur_decay_value
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+ progress_bar.close()
+
+ accelerator.wait_for_everyone()
+
+ # Generate sample images for visual inspection
+ if accelerator.is_main_process:
+ if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
+ unet = accelerator.unwrap_model(model)
+
+ if args.use_ema:
+ ema_model.store(unet.parameters())
+ ema_model.copy_to(unet.parameters())
+
+ pipeline = DDPMPipeline(
+ unet=unet,
+ scheduler=noise_scheduler,
+ )
+
+ generator = torch.Generator(device=pipeline.device).manual_seed(0)
+ # run pipeline in inference (sample random noise and denoise)
+ images = pipeline(
+ generator=generator,
+ batch_size=args.eval_batch_size,
+ num_inference_steps=args.ddpm_num_inference_steps,
+ output_type="np",
+ ).images
+
+ if args.use_ema:
+ ema_model.restore(unet.parameters())
+
+ # denormalize the images and save to tensorboard
+ images_processed = (images * 255).round().astype("uint8")
+
+ if args.logger == "tensorboard":
+ if is_accelerate_version(">=", "0.17.0.dev0"):
+ tracker = accelerator.get_tracker("tensorboard", unwrap=True)
+ else:
+ tracker = accelerator.get_tracker("tensorboard")
+ tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
+ elif args.logger == "wandb":
+ # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files
+ accelerator.get_tracker("wandb").log(
+ {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
+ step=global_step,
+ )
+
+ if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+ # save the model
+ unet = accelerator.unwrap_model(model)
+
+ if args.use_ema:
+ ema_model.store(unet.parameters())
+ ema_model.copy_to(unet.parameters())
+
+ pipeline = DDPMPipeline(
+ unet=unet,
+ scheduler=noise_scheduler,
+ )
+
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.use_ema:
+ ema_model.restore(unet.parameters())
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message=f"Epoch {epoch}",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/vqgan/README.md b/diffusers/examples/vqgan/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..056bbcaf67487dd538f7dff8b075b2be4a5f814a
--- /dev/null
+++ b/diffusers/examples/vqgan/README.md
@@ -0,0 +1,127 @@
+## Training a VQGAN VAE
+VQVAEs were first introduced in [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) and were later combined with a GAN in [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The basic idea of a VQVAE is that it is a variational autoencoder whose latent space is made of discrete tokens, similar to the tokens used by LLMs. This script was adapted from a [PR to huggingface's open-muse project](https://github.com/huggingface/open-muse/pull/52), with general code following [lucidrains' implementation of the VQGAN training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py); both of these implementations follow the [taming-transformers repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file).
+
+
+Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run:
+```bash
+pip install -r requirements.txt
+```
+
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+### Training on CIFAR10
+
+The command to train a VQGAN model on the CIFAR10 dataset:
+
+```bash
+accelerate launch train_vqgan.py \
+ --dataset_name=cifar10 \
+ --image_column=img \
+ --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
+ --resolution=128 \
+ --train_batch_size=2 \
+ --gradient_accumulation_steps=8 \
+ --report_to=wandb
+```
+
+An example training run is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp) by @sayakpaul and a smaller-scale one [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images).
+The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is to increase the image resolution. Other ways include, but are not limited to, lowering compression by downsampling fewer times or increasing the vocabulary size, which can be at most around 16384. How to do this is shown below.
+
+### Modifying the architecture
+
+To modify the architecture of the VQGAN model, you can save the config taken from [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder/blob/main/movq/config.json) and then provide it to the script with the option `--model_config_name_or_path`. The config is shown below:
+```json
+{
+ "_class_name": "VQModel",
+ "_diffusers_version": "0.17.0.dev0",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 256,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "AttnDownEncoderBlock2D"
+ ],
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "norm_type": "spatial",
+ "num_vq_embeddings": 16384,
+ "out_channels": 3,
+ "sample_size": 32,
+ "scaling_factor": 0.18215,
+ "up_block_types": [
+ "AttnUpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "vq_embed_dim": 4
+}
+```
+To reduce the number of layers in a VQGAN, you can remove layers by modifying `block_out_channels`, `down_block_types`, and `up_block_types`, as shown below:
+```json
+{
+ "_class_name": "VQModel",
+ "_diffusers_version": "0.17.0.dev0",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 256
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "norm_type": "spatial",
+ "num_vq_embeddings": 16384,
+ "out_channels": 3,
+ "sample_size": 32,
+ "scaling_factor": 0.18215,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "vq_embed_dim": 4
+}
+```
+To increase the vocabulary size, you can increase `num_vq_embeddings`. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that VQGAN representations start degrading beyond 2^14 = 16384 vq embeddings, so it is not recommended to go past that.
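+
+A minimal sketch of launching training with a modified architecture, assuming you saved the edited JSON above as `my_vqgan_config.json` (a placeholder filename) and reusing the CIFAR10 settings from the earlier example:
+
+```bash
+accelerate launch train_vqgan.py \
+  --dataset_name=cifar10 \
+  --image_column=img \
+  --resolution=128 \
+  --train_batch_size=2 \
+  --gradient_accumulation_steps=8 \
+  --model_config_name_or_path=my_vqgan_config.json \
+  --report_to=wandb
+```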
+
+### Extra training tips/ideas
+- During logging, take care to make sure `data_time` stays low. `data_time` is the time spent loading data, during which the GPU is idle, so it is essentially wasted time. The easiest way to lower it is to increase `--dataloader_num_workers` to a higher number like 4. Due to a bug in PyTorch, this only works on Linux-based systems. For more details check [here](https://github.com/huggingface/diffusers/issues/7646).
+- Training can be considered done when both the discriminator loss and the generator loss have converged.
+- Another low-hanging fruit is enabling EMA with the `--use_ema` parameter. This tends to make the output images smoother. The downside is that you have to lower your batch size by 1, but it may be worth it.
+- A more experimental low-hanging fruit is swapping vgg19 for a different model for the LPIPS loss via `--timm_model_backend`. If you do this, it is recommended to also change the `--timm_model_layers` parameter to the layer in your model that you think best represents the features. However, be careful with the feature map norms, since they can easily dominate the loss. A combined launch example is sketched after this list.
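+
+A rough sketch of combining these tips into a single launch, reusing the CIFAR10 settings from the earlier example (`--timm_model_backend=vgg19` and `--timm_model_layers=head` are simply the defaults, shown here as the place to swap in a different backbone and layer):
+
+```bash
+accelerate launch train_vqgan.py \
+  --dataset_name=cifar10 \
+  --image_column=img \
+  --resolution=128 \
+  --train_batch_size=2 \
+  --gradient_accumulation_steps=8 \
+  --dataloader_num_workers=4 \
+  --use_ema \
+  --timm_model_backend=vgg19 \
+  --timm_model_layers=head \
+  --report_to=wandb
+```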
\ No newline at end of file
diff --git a/diffusers/examples/vqgan/discriminator.py b/diffusers/examples/vqgan/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb31c3cb0165533f5587f6c40a046ae7c6555882
--- /dev/null
+++ b/diffusers/examples/vqgan/discriminator.py
@@ -0,0 +1,48 @@
+"""
+Ported from Paella
+"""
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+
+
+# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
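+# Note: with the default depth=6 there are six stride-2 convolutions, so the input is downsampled by
+# 2**6 = 64x; e.g. a 3x256x256 image yields a (1, 4, 4) map of per-patch sigmoid scores.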
+class Discriminator(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
+ super().__init__()
+ d = max(depth - 3, 3)
+ layers = [
+ nn.utils.spectral_norm(
+ nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
+ ),
+ nn.LeakyReLU(0.2),
+ ]
+ for i in range(depth - 1):
+ c_in = hidden_channels // (2 ** max((d - i), 0))
+ c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
+ layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+ layers.append(nn.InstanceNorm2d(c_out))
+ layers.append(nn.LeakyReLU(0.2))
+ self.encoder = nn.Sequential(*layers)
+ self.shuffle = nn.Conv2d(
+ (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
+ )
+ self.logits = nn.Sigmoid()
+
+ def forward(self, x, cond=None):
+ x = self.encoder(x)
+ if cond is not None:
+ cond = cond.view(
+ cond.size(0),
+ cond.size(1),
+ 1,
+ 1,
+ ).expand(-1, -1, x.size(-2), x.size(-1))
+ x = torch.cat([x, cond], dim=1)
+ x = self.shuffle(x)
+ x = self.logits(x)
+ return x
diff --git a/diffusers/examples/vqgan/requirements.txt b/diffusers/examples/vqgan/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f204a70f1e0e7fc38f5f262a697d9781beb300e6
--- /dev/null
+++ b/diffusers/examples/vqgan/requirements.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+timm
+numpy
+tqdm
+tensorboard
\ No newline at end of file
diff --git a/diffusers/examples/vqgan/test_vqgan.py b/diffusers/examples/vqgan/test_vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..664a7f7365b05cff8edaa8d20f9fbb01e3c56407
--- /dev/null
+++ b/diffusers/examples/vqgan/test_vqgan.py
@@ -0,0 +1,395 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+import torch
+
+from diffusers import VQModel
+from diffusers.utils.testing_utils import require_timm
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+@require_timm
+class TextToImage(ExamplesTestsAccelerate):
+ @property
+ def test_vqmodel_config(self):
+ return {
+ "_class_name": "VQModel",
+ "_diffusers_version": "0.17.0.dev0",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 32,
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ ],
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "norm_type": "spatial",
+ "num_vq_embeddings": 32,
+ "out_channels": 3,
+ "sample_size": 32,
+ "scaling_factor": 0.18215,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ ],
+ "vq_embed_dim": 4,
+ }
+
+ @property
+ def test_discriminator_config(self):
+ return {
+ "_class_name": "Discriminator",
+ "_diffusers_version": "0.27.0.dev0",
+ "in_channels": 3,
+ "cond_channels": 0,
+ "hidden_channels": 8,
+ "depth": 4,
+ }
+
+ def get_vq_and_discriminator_configs(self, tmpdir):
+ vqmodel_config_path = os.path.join(tmpdir, "vqmodel.json")
+ discriminator_config_path = os.path.join(tmpdir, "discriminator.json")
+ with open(vqmodel_config_path, "w") as fp:
+ json.dump(self.test_vqmodel_config, fp)
+ with open(discriminator_config_path, "w") as fp:
+ json.dump(self.test_discriminator_config, fp)
+ return vqmodel_config_path, discriminator_config_path
+
+ def test_vqmodel(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
+ test_args = f"""
+ examples/vqgan/train_vqgan.py
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 32
+ --image_column image
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --model_config_name_or_path {vqmodel_config_path}
+ --discriminator_config_name_or_path {discriminator_config_path}
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(
+ os.path.isfile(os.path.join(tmpdir, "discriminator", "diffusion_pytorch_model.safetensors"))
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "vqmodel", "diffusion_pytorch_model.safetensors")))
+
+ def test_vqmodel_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/vqgan/train_vqgan.py
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 32
+ --image_column image
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --model_config_name_or_path {vqmodel_config_path}
+ --discriminator_config_name_or_path {discriminator_config_path}
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel")
+ image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+ _ = model(image)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4"},
+ )
+
+ # Run training script for 2 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/vqgan/train_vqgan.py
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 32
+ --image_column image
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --model_config_name_or_path {vqmodel_config_path}
+ --discriminator_config_name_or_path {discriminator_config_path}
+ --checkpointing_steps=1
+ --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
+ image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+ _ = model(image)
+
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ # In the current script, checkpointing_steps=1 is equivalent to checkpointing_steps=2: after the generator is trained for one step,
+ # the discriminator is trained, and logging/saving only happens after that. Thus we do not expect to get a checkpoint-5.
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_vqmodel_checkpointing_use_ema(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/vqgan/train_vqgan.py
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 32
+ --image_column image
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --model_config_name_or_path {vqmodel_config_path}
+ --discriminator_config_name_or_path {discriminator_config_path}
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
+ image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+ _ = model(image)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel")
+ image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+ _ = model(image)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 2 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/vqgan/train_vqgan.py
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 32
+ --image_column image
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --model_config_name_or_path {vqmodel_config_path}
+ --discriminator_config_name_or_path {discriminator_config_path}
+ --checkpointing_steps=1
+ --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --output_dir {tmpdir}
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
+ image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+ _ = model(image)
+
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_vqmodel_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
+ # Run training script with checkpointing
+ # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/vqgan/train_vqgan.py
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 32
+ --image_column image
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --model_config_name_or_path {vqmodel_config_path}
+ --discriminator_config_name_or_path {discriminator_config_path}
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
+ image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+ _ = model(image)
+
+ # check checkpoint directories exist
+ # checkpoint-2 should have been deleted
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
+
+ def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/vqgan/train_vqgan.py
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 32
+ --image_column image
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --model_config_name_or_path {vqmodel_config_path}
+ --discriminator_config_name_or_path {discriminator_config_path}
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
+ image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+ _ = model(image)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # resume and we should try to checkpoint at 6, where we'll have to remove
+ # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
+
+ resume_run_args = f"""
+ examples/vqgan/train_vqgan.py
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 32
+ --image_column image
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 8
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --model_config_name_or_path {vqmodel_config_path}
+ --discriminator_config_name_or_path {discriminator_config_path}
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --checkpoints_total_limit=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
+ image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+ _ = model(image)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8"},
+ )
diff --git a/diffusers/examples/vqgan/train_vqgan.py b/diffusers/examples/vqgan/train_vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..d16dce921896aa9cced7944f29bbfa2adf99e507
--- /dev/null
+++ b/diffusers/examples/vqgan/train_vqgan.py
@@ -0,0 +1,1067 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import math
+import os
+import shutil
+import time
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import PIL
+import PIL.Image
+import timm
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from discriminator import Discriminator
+from huggingface_hub import create_repo
+from packaging import version
+from PIL import Image
+from timm.data import resolve_data_config
+from timm.data.transforms_factory import create_transform
+from torchvision import transforms
+from tqdm import tqdm
+
+from diffusers import VQModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, is_wandb_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def _map_layer_to_idx(backbone, layers, offset=0):
+ """Maps a set of layer names to their indices in the timm model. Ported from anomalib.
+
+ Returns:
+ The list of indices corresponding to the requested layers.
+ """
+ idx = []
+ features = timm.create_model(
+ backbone,
+ pretrained=False,
+ features_only=False,
+ exportable=True,
+ )
+ for i in layers:
+ try:
+ idx.append(list(dict(features.named_children()).keys()).index(i) - offset)
+ except ValueError:
+ raise ValueError(
+ f"Layer {i} not found in model {backbone}. Select layer from {list(dict(features.named_children()).keys())}. The network architecture is {features}"
+ )
+ return idx
+
+
+def get_perceptual_loss(pixel_values, fmap, timm_model, timm_model_resolution, timm_model_normalization):
+ img_timm_model_input = timm_model_normalization(F.interpolate(pixel_values, timm_model_resolution))
+ fmap_timm_model_input = timm_model_normalization(F.interpolate(fmap, timm_model_resolution))
+
+ if pixel_values.shape[1] == 1:
+ # handle grayscale for timm_model
+ img_timm_model_input, fmap_timm_model_input = (
+ t.repeat(1, 3, 1, 1) for t in (img_timm_model_input, fmap_timm_model_input)
+ )
+
+ img_timm_model_feats = timm_model(img_timm_model_input)
+ recon_timm_model_feats = timm_model(fmap_timm_model_input)
+ perceptual_loss = F.mse_loss(img_timm_model_feats[0], recon_timm_model_feats[0])
+ for i in range(1, len(img_timm_model_feats)):
+ perceptual_loss += F.mse_loss(img_timm_model_feats[i], recon_timm_model_feats[i])
+ perceptual_loss /= len(img_timm_model_feats)
+ return perceptual_loss
+
+
+def grad_layer_wrt_loss(loss, layer):
+ return torch.autograd.grad(
+ outputs=loss,
+ inputs=layer,
+ grad_outputs=torch.ones_like(loss),
+ retain_graph=True,
+ )[0].detach()
+
+
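+# Gradient penalty in the WGAN-GP form: encourages the norm of the discriminator's gradient with
+# respect to its input to stay close to 1.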
+def gradient_penalty(images, output, weight=10):
+ gradients = torch.autograd.grad(
+ outputs=output,
+ inputs=images,
+ grad_outputs=torch.ones(output.size(), device=images.device),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True,
+ )[0]
+ bsz = gradients.shape[0]
+ gradients = torch.reshape(gradients, (bsz, -1))
+ return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+
+
+@torch.no_grad()
+def log_validation(model, args, validation_transform, accelerator, global_step):
+ logger.info("Generating images...")
+ dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ dtype = torch.bfloat16
+ original_images = []
+ for image_path in args.validation_images:
+ image = PIL.Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = validation_transform(image).to(accelerator.device, dtype=dtype)
+ original_images.append(image[None])
+ # Generate images
+ model.eval()
+ images = []
+ for original_image in original_images:
+ image = accelerator.unwrap_model(model)(original_image).sample
+ images.append(image)
+ model.train()
+ original_images = torch.cat(original_images, dim=0)
+ images = torch.cat(images, dim=0)
+
+ # Convert to PIL images
+ images = torch.clamp(images, 0.0, 1.0)
+ original_images = torch.clamp(original_images, 0.0, 1.0)
+ images *= 255.0
+ original_images *= 255.0
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
+ original_images = original_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
+ images = np.concatenate([original_images, images], axis=2)
+ images = [Image.fromarray(image) for image in images]
+
+ # Log images
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: Original, Generated") for i, image in enumerate(images)
+ ]
+ },
+ step=global_step,
+ )
+ torch.cuda.empty_cache()
+ return images
+
+
+def log_grad_norm(model, accelerator, global_step):
+ for name, param in model.named_parameters():
+ if param.grad is not None:
+ grads = param.grad.detach().data
+ grad_norm = (grads.norm(p=2) / grads.numel()).item()
+ accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--log_grad_norm_steps",
+ type=int,
+ default=500,
+ help=("Print logs of gradient norms every X steps."),
+ )
+ parser.add_argument(
+ "--log_steps",
+ type=int,
+ default=50,
+ help=("Print logs every X steps."),
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running reconstruction on images in"
+ " `args.validation_images` and logging the reconstructed images."
+ ),
+ )
+ parser.add_argument(
+ "--vae_loss",
+ type=str,
+ default="l2",
+ help="The loss function for vae reconstruction loss.",
+ )
+ parser.add_argument(
+ "--timm_model_offset",
+ type=int,
+ default=0,
+ help="Offset of timm layers to indices.",
+ )
+ parser.add_argument(
+ "--timm_model_layers",
+ type=str,
+ default="head",
+ help="The layers to get output from in the timm model.",
+ )
+ parser.add_argument(
+ "--timm_model_backend",
+ type=str,
+ default="vgg19",
+ help="Timm model backbone used to compute the LPIPS loss.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--model_config_name_or_path",
+ type=str,
+ default=None,
+ help="The config of the VQ model to train, leave as None to use the standard VQ model configuration.",
+ )
+ parser.add_argument(
+ "--discriminator_config_name_or_path",
+ type=str,
+ default=None,
+ help="The config of the discriminator model to train, leave as None to use the standard discriminator configuration.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_images",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of validation images evaluated every `--validation_steps` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="vqgan-output",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--discr_learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--discr_lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type to use for training. Choose between 'epsilon' or 'v_prediction', or leave as `None`. If left as `None`, the scheduler's default prediction type (`noise_scheduler.config.prediction_type`) is used.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="vqgan-training",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For more information, see"
+ " https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def main():
+ #########################
+ # SETUP Accelerator #
+ #########################
+ args = parse_args()
+
+ # Enable TF32 on Ampere GPUs
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.deterministic = False
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if accelerator.distributed_type == DistributedType.DEEPSPEED:
+ accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+
+ #####################################
+ # SETUP LOGGING, SEED and CONFIG #
+ #####################################
+
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_images")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ #########################
+ # MODELS and OPTIMIZER #
+ #########################
+ logger.info("Loading models and optimizer")
+
+ if args.model_config_name_or_path is None and args.pretrained_model_name_or_path is None:
+ # Taken from the config of the movq model at kandinsky-community/kandinsky-2-2-decoder, but without the attention layers
+ model = VQModel(
+ act_fn="silu",
+ block_out_channels=[
+ 128,
+ 256,
+ 512,
+ ],
+ down_block_types=[
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ],
+ in_channels=3,
+ latent_channels=4,
+ layers_per_block=2,
+ norm_num_groups=32,
+ norm_type="spatial",
+ num_vq_embeddings=16384,
+ out_channels=3,
+ sample_size=32,
+ scaling_factor=0.18215,
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+ vq_embed_dim=4,
+ )
+ elif args.pretrained_model_name_or_path is not None:
+ model = VQModel.from_pretrained(args.pretrained_model_name_or_path)
+ else:
+ config = VQModel.load_config(args.model_config_name_or_path)
+ model = VQModel.from_config(config)
+ if args.use_ema:
+ ema_model = EMAModel(model.parameters(), model_cls=VQModel, model_config=model.config)
+ if args.discriminator_config_name_or_path is None:
+ discriminator = Discriminator()
+ else:
+ config = Discriminator.load_config(args.discriminator_config_name_or_path)
+ discriminator = Discriminator.from_config(config)
+
+ idx = _map_layer_to_idx(args.timm_model_backend, args.timm_model_layers.split("|"), args.timm_model_offset)
+
+ timm_model = timm.create_model(
+ args.timm_model_backend,
+ pretrained=True,
+ features_only=True,
+ exportable=True,
+ out_indices=idx,
+ )
+ timm_model = timm_model.to(accelerator.device)
+ timm_model.requires_grad_(False)  # freeze the perceptual backbone
+ timm_model.eval()
+ timm_transform = create_transform(**resolve_data_config(timm_model.pretrained_cfg, model=timm_model))
+ try:
+ # Gets the resolution of the timm transformation after centercrop
+ timm_centercrop_transform = timm_transform.transforms[1]
+ assert isinstance(
+ timm_centercrop_transform, transforms.CenterCrop
+ ), f"Timm model {args.timm_model_backend} is currently incompatible with this script. Try vgg19."
+ timm_model_resolution = timm_centercrop_transform.size[0]
+ # Gets final normalization
+ timm_model_normalization = timm_transform.transforms[-1]
+ assert isinstance(
+ timm_model_normalization, transforms.Normalize
+ ), f"Timm model {args.timm_model_backend} is currently incompatible with this script. Try vgg19."
+ except AssertionError as e:
+ raise NotImplementedError(e)
+ # Enable flash attention if asked
+ if args.enable_xformers_memory_efficient_attention:
+ model.enable_xformers_memory_efficient_attention()
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_model.save_pretrained(os.path.join(output_dir, "vqmodel_ema"))
+ vqmodel = models[0]
+ discriminator = models[1]
+ vqmodel.save_pretrained(os.path.join(output_dir, "vqmodel"))
+ discriminator.save_pretrained(os.path.join(output_dir, "discriminator"))
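+ # pop the weights so that `accelerator.save_state` does not serialize the same models a second time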
+ weights.pop()
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "vqmodel_ema"), VQModel)
+ ema_model.load_state_dict(load_model.state_dict())
+ ema_model.to(accelerator.device)
+ del load_model
+ discriminator = models.pop()
+ load_model = Discriminator.from_pretrained(input_dir, subfolder="discriminator")
+ discriminator.register_to_config(**load_model.config)
+ discriminator.load_state_dict(load_model.state_dict())
+ del load_model
+ vqmodel = models.pop()
+ load_model = VQModel.from_pretrained(input_dir, subfolder="vqmodel")
+ vqmodel.register_to_config(**load_model.config)
+ vqmodel.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ learning_rate = args.learning_rate
+ if args.scale_lr:
+ learning_rate = (
+ learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ list(model.parameters()),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+ discr_optimizer = optimizer_cls(
+ list(discriminator.parameters()),
+ lr=args.discr_learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ ##################################
+ # DATALOADER and LR-SCHEDULER #
+ #################################
+ logger.info("Creating dataloaders and lr_scheduler")
+
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ # DataLoaders creation:
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ assert args.image_column is not None
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(f"--image_column value '{args.image_column}' needs to be one of: {', '.join(column_names)}")
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ ]
+ )
+ validation_transform = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ return {"pixel_values": pixel_values}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.lr_warmup_steps,
+ )
+ discr_lr_scheduler = get_scheduler(
+ args.discr_lr_scheduler,
+ optimizer=discr_optimizer,
+ num_training_steps=args.max_train_steps,
+ num_warmup_steps=args.lr_warmup_steps,
+ )
+
+ # Prepare everything with accelerator
+ logger.info("Preparing model, optimizer and dataloaders")
+ # The dataloaders are already aware of distributed training, so we don't need to prepare them.
+ model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler = accelerator.prepare(
+ model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler
+ )
+ if args.use_ema:
+ ema_model.to(accelerator.device)
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Potentially load in the weights and states from a previous save
+ resume_from_checkpoint = args.resume_from_checkpoint
+ if resume_from_checkpoint:
+ if resume_from_checkpoint != "latest":
+ path = resume_from_checkpoint
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = os.path.join(args.output_dir, dirs[-1]) if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
+ resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(path)
+ accelerator.wait_for_everyone()
+ global_step = int(os.path.basename(path).split("-")[1])
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ batch_time_m = AverageMeter()
+ data_time_m = AverageMeter()
+ end = time.time()
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+ # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
+ # reuse the same training loop with other datasets/loaders.
+ avg_gen_loss, avg_discr_loss = None, None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ model.train()
+ discriminator.train()
+ for i, batch in enumerate(train_dataloader):
+ pixel_values = batch["pixel_values"]
+ pixel_values = pixel_values.to(accelerator.device, non_blocking=True)
+ data_time_m.update(time.time() - end)
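+ # Alternate between a generator update and a discriminator update every `gradient_accumulation_steps` batches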
+ generator_step = ((i // args.gradient_accumulation_steps) % 2) == 0
+ # Train Step
+ # The behavior of accelerator.accumulate is to:
+ # 1. Check whether gradients should be synced (i.e. gradient_accumulation_steps has been reached).
+ # 2. If so, sync gradients across processes; otherwise skip the synchronization for this step.
+ if generator_step:
+ optimizer.zero_grad(set_to_none=True)
+ else:
+ discr_optimizer.zero_grad(set_to_none=True)
+ # encode images to the latent space and get the commit loss from vq tokenization
+ # Return commit loss
+ fmap, commit_loss = model(pixel_values, return_dict=False)
+
+ if generator_step:
+ with accelerator.accumulate(model):
+ # reconstruction loss. Pixel level differences between input vs output
+ if args.vae_loss == "l2":
+ loss = F.mse_loss(pixel_values, fmap)
+ else:
+ loss = F.l1_loss(pixel_values, fmap)
+ # perceptual loss. The high level feature mean squared error loss
+ perceptual_loss = get_perceptual_loss(
+ pixel_values,
+ fmap,
+ timm_model,
+ timm_model_resolution=timm_model_resolution,
+ timm_model_normalization=timm_model_normalization,
+ )
+ # generator loss
+ gen_loss = -discriminator(fmap).mean()
+ last_dec_layer = accelerator.unwrap_model(model).decoder.conv_out.weight
+ norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)
+ norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
+
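+ # Adaptive adversarial weight (as in VQGAN): the ratio of the gradient norms of the perceptual
+ # loss and the generator loss at the last decoder layer, clamped for numerical stability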
+ adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-8)
+ adaptive_weight = adaptive_weight.clamp(max=1e4)
+ loss += commit_loss
+ loss += perceptual_loss
+ loss += adaptive_weight * gen_loss
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_gen_loss = accelerator.gather(loss.repeat(args.train_batch_size)).float().mean()
+ accelerator.backward(loss)
+
+ if args.max_grad_norm is not None and accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ # log gradient norm before zeroing it
+ if (
+ accelerator.sync_gradients
+ and global_step % args.log_grad_norm_steps == 0
+ and accelerator.is_main_process
+ ):
+ log_grad_norm(model, accelerator, global_step)
+ else:
+ # Return discriminator loss
+ with accelerator.accumulate(discriminator):
+ fmap.detach_()
+ pixel_values.requires_grad_()
+ real = discriminator(pixel_values)
+ fake = discriminator(fmap)
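+ # Hinge loss for the discriminator on real vs. reconstructed images, plus a gradient penalty on the real inputs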
+ loss = (F.relu(1 + fake) + F.relu(1 - real)).mean()
+ gp = gradient_penalty(pixel_values, real)
+ loss += gp
+ avg_discr_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ accelerator.backward(loss)
+
+ if args.max_grad_norm is not None and accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm)
+
+ discr_optimizer.step()
+ discr_lr_scheduler.step()
+ if (
+ accelerator.sync_gradients
+ and global_step % args.log_grad_norm_steps == 0
+ and accelerator.is_main_process
+ ):
+ log_grad_norm(discriminator, accelerator, global_step)
+ batch_time_m.update(time.time() - end)
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ global_step += 1
+ progress_bar.update(1)
+ if args.use_ema:
+ ema_model.step(model.parameters())
+ if accelerator.sync_gradients and not generator_step and accelerator.is_main_process:
+ # wait for both generator and discriminator to settle
+ # Log metrics
+ if global_step % args.log_steps == 0:
+ samples_per_second_per_gpu = (
+ args.gradient_accumulation_steps * args.train_batch_size / batch_time_m.val
+ )
+ logs = {
+ "step_discr_loss": avg_discr_loss.item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ "samples/sec/gpu": samples_per_second_per_gpu,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ }
+ if avg_gen_loss is not None:
+ logs["step_gen_loss"] = avg_gen_loss.item()
+ accelerator.log(logs, step=global_step)
+
+ # resetting batch / data time meters per log window
+ batch_time_m.reset()
+ data_time_m.reset()
+ # Save model checkpoint
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ # Generate images
+ if global_step % args.validation_steps == 0:
+ if args.use_ema:
+ # Store the VQGAN parameters temporarily and load the EMA parameters to perform inference.
+ ema_model.store(model.parameters())
+ ema_model.copy_to(model.parameters())
+ log_validation(model, args, validation_transform, accelerator, global_step)
+ if args.use_ema:
+ # Switch back to the original VQGAN parameters.
+ ema_model.restore(model.parameters())
+ end = time.time()
+ # Stop training if max steps is reached
+ if global_step >= args.max_train_steps:
+ break
+ # End for
+
+ accelerator.wait_for_everyone()
+
+ # Save the final trained checkpoint
+ if accelerator.is_main_process:
+ model = accelerator.unwrap_model(model)
+ discriminator = accelerator.unwrap_model(discriminator)
+ if args.use_ema:
+ ema_model.copy_to(model.parameters())
+ model.save_pretrained(os.path.join(args.output_dir, "vqmodel"))
+ discriminator.save_pretrained(os.path.join(args.output_dir, "discriminator"))
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/wuerstchen/text_to_image/README.md b/diffusers/examples/wuerstchen/text_to_image/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a6ec4698b6118216ab3d1293cafa58b1ef133024
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/README.md
@@ -0,0 +1,93 @@
+# Würstchen text-to-image fine-tuning
+
+## Running locally with PyTorch
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+cd examples/wuerstchen/text_to_image
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run:
+```bash
+huggingface-cli login
+```
+
+## Prior training
+
+You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that `--gradient_checkpointing` is currently supported for prior model fine-tuning, so you can use it in setups with limited GPU memory.
+
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch train_text_to_image_prior.py \
+ --mixed_precision="fp16" \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --dataloader_num_workers=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot naruto, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="wuerstchen-prior-naruto-model"
+```
+
+
+## Training with LoRA
+
+Low-Rank Adaptation of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers make it possible to control the extent to which the model is adapted toward new training images via a `scale` parameter (see the sketch below).
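+
+To build intuition, here is a minimal, hypothetical sketch of the idea (the training script itself relies on `peft`'s `LoraConfig` rather than a hand-rolled class like this): a frozen linear layer is augmented with a trainable low-rank update, and `scale` controls how strongly that update is applied.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class LoRALinear(nn.Module):
+    """A frozen nn.Linear plus a trainable low-rank update: y = Wx + scale * B(Ax)."""
+
+    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
+        super().__init__()
+        self.base = base
+        self.base.requires_grad_(False)  # pretrained weights stay frozen
+        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
+        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init: no change at start
+        self.scale = alpha / rank
+
+    def forward(self, x):
+        return self.base(x) + self.scale * (x @ self.lora_a.T @ self.lora_b.T)
+```
+
+Because `lora_b` is initialized to zero, the wrapped layer initially behaves exactly like the pretrained one; only the two small matrices receive gradients during fine-tuning.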
+
+
+### Prior Training
+
+First, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Naruto captions dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch train_text_to_image_lora_prior.py \
+ --mixed_precision="fp16" \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=768 \
+ --train_batch_size=8 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --rank=4 \
+ --validation_prompt="cute dragon creature" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="wuerstchen-prior-naruto-lora"
+```
diff --git a/diffusers/examples/wuerstchen/text_to_image/__init__.py b/diffusers/examples/wuerstchen/text_to_image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffusers/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/diffusers/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd551ebf1623af6d03cff998b258f19ee1294245
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py
@@ -0,0 +1,23 @@
+import torch.nn as nn
+from torchvision.models import efficientnet_v2_l, efficientnet_v2_s
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+
+
+class EfficientNetEncoder(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
+ super().__init__()
+
+ if effnet == "efficientnet_v2_s":
+ self.backbone = efficientnet_v2_s(weights="DEFAULT").features
+ else:
+ self.backbone = efficientnet_v2_l(weights="DEFAULT").features
+ self.mapper = nn.Sequential(
+ nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
+ nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
+ )
+
+ def forward(self, x):
+ return self.mapper(self.backbone(x))
diff --git a/diffusers/examples/wuerstchen/text_to_image/requirements.txt b/diffusers/examples/wuerstchen/text_to_image/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dfe8f3db46a2abc0f1c5f2cbab8278116434ba50
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+wandb
+bitsandbytes
+deepspeed
+peft>=0.6.0
diff --git a/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c9a44bb5b25a63ce742eb543ee73d1b4087171
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -0,0 +1,949 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState, is_initialized
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
+from modeling_efficient_net_encoder import EfficientNetEncoder
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
+from torchvision import transforms
+from tqdm import tqdm
+from transformers import CLIPTextModel, PreTrainedTokenizerFast
+from transformers.utils import ContextManagers
+
+from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "\n"
+
+ yaml = f"""
+---
+license: mit
+base_model: {args.pretrained_prior_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- wuerstchen
+- text-to-image
+- diffusers
+- diffusers-training
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA Finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}.\n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype}
+ )
+# load lora weights from folder:
+pipeline.prior_pipe.load_lora_weights("{repo_id}", torch_dtype={args.weight_dtype})
+
+prompt = "A robot naruto, 4k photo"  # example prompt
+image = pipeline(prompt=prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* LoRA rank: {args.rank}
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior=accelerator.unwrap_model(prior),
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.cuda.amp.autocast():
+ image = pipeline(
+ args.validation_prompts[i],
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ generator=generator,
+ height=args.resolution,
+ width=args.resolution,
+ ).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="warp-ai/wuerstchen",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="warp-ai/wuerstchen-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="wuerstchen-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="learning rate",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+ help="Weight decay to use.",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For more information, see"
+ " https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, effnet, tokenizer, clip_model
+ noise_scheduler = DDPMWuerstchenScheduler()
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ Returns a list containing a context manager that disables DeepSpeed zero.Init, or an empty list if DeepSpeed is not being used.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
+ state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
+ image_encoder = EfficientNetEncoder()
+ image_encoder.load_state_dict(state_dict["effnet_state_dict"])
+ image_encoder.eval()
+
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
+ ).eval()
+
+ # Freeze text_encoder, cast to weight_dtype and image_encoder and move to device
+ text_encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # load prior model, cast to weight_dtype and move to device
+ prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+ prior.to(accelerator.device, dtype=weight_dtype)
+
+ # LoRA configuration targeting the prior's attention projection layers
+ prior_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
+ )
+ # Add adapter and make sure the trainable params are in float32.
+ prior.add_adapter(prior_lora_config)
+ if args.mixed_precision == "fp16":
+ for param in prior.parameters():
+ # only upcast trainable parameters (LoRA) into fp32
+ if param.requires_grad:
+ param.data = param.to(torch.float32)
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ prior_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(prior))):
+ prior_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ WuerstchenPriorPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=prior_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ prior_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(prior))):
+ prior_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = WuerstchenPriorPipeline.lora_state_dict(input_dir)
+ WuerstchenPriorPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=prior_)
+ WuerstchenPriorPipeline.load_lora_into_text_encoder(
+ lora_state_dict,
+ network_alphas=network_alphas,
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+ params_to_optimize = list(filter(lambda p: p.requires_grad, prior.parameters()))
+ optimizer = optimizer_cls(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ text_input_ids = inputs.input_ids
+ text_mask = inputs.attention_mask.bool()
+ return text_input_ids, text_mask
+
+ effnet_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
+ effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
+ text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
+ text_mask = torch.stack([example["text_mask"] for example in examples])
+ return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ prior, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
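+ # Convert the resumed optimizer-step count back into dataloader steps so that
+ # already-seen batches can be skipped within the first resumed epoch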
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ prior.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(prior):
+ # Convert images to latent space
+ text_input_ids, text_mask, effnet_images = (
+ batch["text_input_ids"],
+ batch["text_mask"],
+ batch["effnet_pixel_values"].to(weight_dtype),
+ )
+
+ with torch.no_grad():
+ text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
+ prompt_embeds = text_encoder_output.last_hidden_state
+ image_embeds = image_encoder(effnet_images)
+ # scale
+ image_embeds = image_embeds.add(1.0).div(42.0)
+
+ # Sample noise that we'll add to the image_embeds
+ noise = torch.randn_like(image_embeds)
+ bsz = image_embeds.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
+
+ # add noise to latent
+ noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
+
+ # Predict the noise residual and compute the loss
+ pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
+
+ # vanilla loss
+ loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ prior = accelerator.unwrap_model(prior)
+ prior = prior.to(torch.float32)
+
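+ # Extract only the LoRA adapter weights from the PEFT-wrapped prior so they can be saved as a standalone file.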
+ prior_lora_state_dict = get_peft_model_state_dict(prior)
+
+ WuerstchenPriorPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=prior_lora_state_dict,
+ )
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load lora weights
+ pipeline.prior_pipe.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.cuda.amp.autocast():
+ image = pipeline(
+ args.validation_prompts[i],
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ generator=generator,
+ width=args.resolution,
+ height=args.resolution,
+ ).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..eba8de69203a9f840d70d0ebf25a87ee8050e58d
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -0,0 +1,936 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState, is_initialized
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
+from modeling_efficient_net_encoder import EfficientNetEncoder
+from packaging import version
+from torchvision import transforms
+from tqdm import tqdm
+from transformers import CLIPTextModel, PreTrainedTokenizerFast
+from transformers.utils import ContextManagers
+
+from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+ yaml = f"""
+---
+license: mit
+base_model: {args.pretrained_prior_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- wuerstchen
+- text-to-image
+- diffusers
+- diffusers-training
+inference: true
+---
+ """
+ model_card = f"""
+# Finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}\n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype={args.weight_dtype})
+pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype})
+prompt = "{args.validation_prompts[0]}"
+(image_embeds,) = pipe_prior(prompt).to_tuple()
+image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_prior=accelerator.unwrap_model(prior),
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ args.validation_prompts[i],
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ generator=generator,
+ height=args.resolution,
+ width=args.resolution,
+ ).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="warp-ai/wuerstchen",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="warp-ai/wuerstchen-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="wuerstchen-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images; all images in the train/validation dataset will be resized to this"
+ " resolution."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Learning rate to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+ help="Weight decay to use.",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For more information, see"
+ " https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, effnet, tokenizer, clip_model
+ noise_scheduler = DDPMWuerstchenScheduler()
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ Returns either a context list containing a context manager that disables zero.Init, or an empty context list.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
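+ # Load the frozen image and text encoders outside of DeepSpeed ZeRO-3 weight partitioning (zero.Init), since they are not trained.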
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
+ state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
+ image_encoder = EfficientNetEncoder()
+ image_encoder.load_state_dict(state_dict["effnet_state_dict"])
+ image_encoder.eval()
+
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
+ ).eval()
+
+ # Freeze text_encoder and image_encoder
+ text_encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ # load prior model
+ prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+
+ # Create EMA for the prior
+ if args.use_ema:
+ ema_prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+ ema_prior = EMAModel(ema_prior.parameters(), model_cls=WuerstchenPrior, model_config=ema_prior.config)
+ ema_prior.to(accelerator.device)
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if args.use_ema:
+ ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "prior"))
+
+ # make sure to pop the weights so that the corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), WuerstchenPrior)
+ ema_prior.load_state_dict(load_model.state_dict())
+ ema_prior.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder="prior")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ prior.enable_gradient_checkpointing()
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+ optimizer = optimizer_cls(
+ prior.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ text_input_ids = inputs.input_ids
+ text_mask = inputs.attention_mask.bool()
+ return text_input_ids, text_mask
+
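+ # EfficientNet-style preprocessing: resize/center-crop to the training resolution and normalize with the standard ImageNet statistics.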
+ effnet_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
+ effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
+ text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
+ text_mask = torch.stack([example["text_mask"] for example in examples])
+ return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ prior, optimizer, train_dataloader, lr_scheduler
+ )
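+ # The frozen image and text encoders are not passed to accelerator.prepare, so move them to the accelerator device manually.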
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ prior.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(prior):
+ # Convert images to latent space
+ text_input_ids, text_mask, effnet_images = (
+ batch["text_input_ids"],
+ batch["text_mask"],
+ batch["effnet_pixel_values"].to(weight_dtype),
+ )
+
+ with torch.no_grad():
+ text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
+ prompt_embeds = text_encoder_output.last_hidden_state
+ image_embeds = image_encoder(effnet_images)
+ # scale the EfficientNet embeddings into the latent range expected by the prior
+ image_embeds = image_embeds.add(1.0).div(42.0)
+
+ # Sample noise that we'll add to the image_embeds
+ noise = torch.randn_like(image_embeds)
+ bsz = image_embeds.shape[0]
+
+ # Sample a random timestep for each image
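+ # (DDPMWuerstchenScheduler works with continuous timesteps, hence the uniform sample from [0, 1) rather than integer steps)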
+ timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
+
+ # add noise to latent
+ noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
+
+ # Predict the noise residual and compute the loss
+ pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
+
+ # vanilla loss
+ loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_prior.step(prior.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the prior parameters temporarily and load the EMA parameters to perform inference.
+ ema_prior.store(prior.parameters())
+ ema_prior.copy_to(prior.parameters())
+ log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)
+ if args.use_ema:
+ # Switch back to the original prior parameters.
+ ema_prior.restore(prior.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ prior = accelerator.unwrap_model(prior)
+ if args.use_ema:
+ ema_prior.copy_to(prior.parameters())
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_prior=prior,
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ )
+ pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, "prior_pipeline"))
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ args.validation_prompts[i],
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ generator=generator,
+ width=args.resolution,
+ height=args.resolution,
+ ).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/pyproject.toml b/diffusers/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..299865a1225d92be40427f6112af0dab1c2a7993
--- /dev/null
+++ b/diffusers/pyproject.toml
@@ -0,0 +1,29 @@
+[tool.ruff]
+line-length = 119
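+# Lint with `ruff check .` and format with `ruff format .`; both commands pick up this configuration automatically.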
+
+[tool.ruff.lint]
+# Never enforce `E501` (line length violations).
+ignore = ["C901", "E501", "E741", "F402", "F823"]
+select = ["C", "E", "F", "I", "W"]
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+"src/diffusers/utils/dummy_*.py" = ["F401"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+known-first-party = ["diffusers"]
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
diff --git a/diffusers/scripts/__init__.py b/diffusers/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffusers/scripts/change_naming_configs_and_checkpoints.py b/diffusers/scripts/change_naming_configs_and_checkpoints.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc1605e95b3e44d434ef810c49571cea1cb4643
--- /dev/null
+++ b/diffusers/scripts/change_naming_configs_and_checkpoints.py
@@ -0,0 +1,113 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Conversion script for the LDM checkpoints."""
+
+import argparse
+import json
+import os
+
+import torch
+from transformers.file_utils import has_file
+
+from diffusers import UNet2DConditionModel, UNet2DModel
+
+
+do_only_config = False
+do_only_weights = True
+do_only_renaming = False
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--repo_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to the local model repo/folder whose config and weights should be renamed.",
+ )
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+
+ args = parser.parse_args()
+
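+ # Mapping from legacy LDM-style config keys to the diffusers UNet config keys.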
+ config_parameters_to_change = {
+ "image_size": "sample_size",
+ "num_res_blocks": "layers_per_block",
+ "block_channels": "block_out_channels",
+ "down_blocks": "down_block_types",
+ "up_blocks": "up_block_types",
+ "downscale_freq_shift": "freq_shift",
+ "resnet_num_groups": "norm_num_groups",
+ "resnet_act_fn": "act_fn",
+ "resnet_eps": "norm_eps",
+ "num_head_channels": "attention_head_dim",
+ }
+
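+ # Mapping from legacy state_dict key prefixes to the corresponding diffusers module names.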
+ key_parameters_to_change = {
+ "time_steps": "time_proj",
+ "mid": "mid_block",
+ "downsample_blocks": "down_blocks",
+ "upsample_blocks": "up_blocks",
+ }
+
+ subfolder = "" if has_file(args.repo_path, "config.json") else "unet"
+
+ with open(os.path.join(args.repo_path, subfolder, "config.json"), "r", encoding="utf-8") as reader:
+ text = reader.read()
+ config = json.loads(text)
+
+ if do_only_config:
+ for key in config_parameters_to_change.keys():
+ config.pop(key, None)
+
+ if has_file(args.repo_path, "config.json"):
+ model = UNet2DModel(**config)
+ else:
+ class_name = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel
+ model = class_name(**config)
+
+ if do_only_config:
+ model.save_config(os.path.join(args.repo_path, subfolder))
+
+ config = dict(model.config)
+
+ if do_only_renaming:
+ for key, value in config_parameters_to_change.items():
+ if key in config:
+ config[value] = config[key]
+ del config[key]
+
+ config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
+ config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]
+
+ if do_only_weights:
+ state_dict = torch.load(os.path.join(args.repo_path, subfolder, "diffusion_pytorch_model.bin"))
+
+ new_state_dict = {}
+ for param_key, param_value in state_dict.items():
+ if param_key.endswith(".op.bias") or param_key.endswith(".op.weight"):
+ continue
+ has_changed = False
+ for key, new_key in key_parameters_to_change.items():
+ if not has_changed and param_key.split(".")[0] == key:
+ new_state_dict[".".join([new_key] + param_key.split(".")[1:])] = param_value
+ has_changed = True
+ if not has_changed:
+ new_state_dict[param_key] = param_value
+
+ model.load_state_dict(new_state_dict)
+ model.save_pretrained(os.path.join(args.repo_path, subfolder))
diff --git a/diffusers/scripts/conversion_ldm_uncond.py b/diffusers/scripts/conversion_ldm_uncond.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c22cc1ce8f2da514c1f096afaf67f1650e4df8a
--- /dev/null
+++ b/diffusers/scripts/conversion_ldm_uncond.py
@@ -0,0 +1,56 @@
+import argparse
+
+import torch
+import yaml
+
+from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
+
+
+def convert_ldm_original(checkpoint_path, config_path, output_path):
+ config = yaml.safe_load(open(config_path, "r"))  # read the YAML config file instead of parsing the path string
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
+ keys = list(state_dict.keys())
+
+ # extract state_dict for VQVAE
+ first_stage_dict = {}
+ first_stage_key = "first_stage_model."
+ for key in keys:
+ if key.startswith(first_stage_key):
+ first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
+
+ # extract state_dict for UNetLDM
+ unet_state_dict = {}
+ unet_key = "model.diffusion_model."
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
+
+ vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
+ unet_init_args = config["model"]["params"]["unet_config"]["params"]
+
+ vqvae = VQModel(**vqvae_init_args).eval()
+ vqvae.load_state_dict(first_stage_dict)
+
+ unet = UNetLDMModel(**unet_init_args).eval()
+ unet.load_state_dict(unet_state_dict)
+
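+ # Rebuild the DDIM noise scheduler from the diffusion hyperparameters stored in the original LDM config.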
+ noise_scheduler = DDIMScheduler(
+ num_train_timesteps=config["model"]["params"]["timesteps"],
+ beta_schedule="scaled_linear",
+ beta_start=config["model"]["params"]["linear_start"],
+ beta_end=config["model"]["params"]["linear_end"],
+ clip_sample=False,
+ )
+
+ pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
+ pipeline.save_pretrained(output_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_path", type=str, required=True)
+ parser.add_argument("--config_path", type=str, required=True)
+ parser.add_argument("--output_path", type=str, required=True)
+ args = parser.parse_args()
+
+ convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
diff --git a/diffusers/scripts/convert_amused.py b/diffusers/scripts/convert_amused.py
new file mode 100644
index 0000000000000000000000000000000000000000..21be29dfdb99a7093b3060e828c514c8357d03d5
--- /dev/null
+++ b/diffusers/scripts/convert_amused.py
@@ -0,0 +1,523 @@
+import inspect
+import os
+from argparse import ArgumentParser
+
+import numpy as np
+import torch
+from muse import MaskGiTUViT, VQGANModel
+from muse import PipelineMuse as OldPipelineMuse
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import VQModel
+from diffusers.models.attention_processor import AttnProcessor
+from diffusers.models.unets.uvit_2d import UVit2DModel
+from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
+from diffusers.schedulers import AmusedScheduler
+
+
+torch.backends.cuda.enable_flash_sdp(False)
+torch.backends.cuda.enable_mem_efficient_sdp(False)
+torch.backends.cuda.enable_math_sdp(True)
+
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+torch.use_deterministic_algorithms(True)
+
+# Enable CUDNN deterministic mode
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+torch.backends.cuda.matmul.allow_tf32 = False
+
+device = "cuda"
+
+
+def main():
+ args = ArgumentParser()
+ args.add_argument("--model_256", action="store_true")
+ args.add_argument("--write_to", type=str, required=False, default=None)
+ args.add_argument("--transformer_path", type=str, required=False, default=None)
+ args = args.parse_args()
+
+ transformer_path = args.transformer_path
+ subfolder = "transformer"
+
+ if transformer_path is None:
+ if args.model_256:
+ transformer_path = "openMUSE/muse-256"
+ else:
+ transformer_path = (
+ "../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/"
+ )
+ subfolder = None
+
+ old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder)
+
+ old_transformer.to(device)
+
+ old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae")
+ old_vae.to(device)
+
+ vqvae = make_vqvae(old_vae)
+
+ tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
+
+ text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
+ text_encoder.to(device)
+
+ transformer = make_transformer(old_transformer, args.model_256)
+
+ scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id)
+
+ new_pipe = AmusedPipeline(
+ vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler
+ )
+
+ old_pipe = OldPipelineMuse(
+ vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer
+ )
+ old_pipe.to(device)
+
+ if args.model_256:
+ transformer_seq_len = 256
+ orig_size = (256, 256)
+ else:
+ transformer_seq_len = 1024
+ orig_size = (512, 512)
+
+ old_out = old_pipe(
+ "dog",
+ generator=torch.Generator(device).manual_seed(0),
+ transformer_seq_len=transformer_seq_len,
+ orig_size=orig_size,
+ timesteps=12,
+ )[0]
+
+ new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0]
+
+ old_out = np.array(old_out)
+ new_out = np.array(new_out)
+
+ diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64))
+
+ # assert diff.sum() == 0
+ print("skipping pipeline full equivalence check")
+
+ print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}")
+
+ if args.model_256:
+ assert diff.max() <= 3
+ assert diff.sum() / diff.size < 0.7
+ else:
+ assert diff.max() <= 1
+ assert diff.sum() / diff.size < 0.4
+
+ if args.write_to is not None:
+ new_pipe.save_pretrained(args.write_to)
+
+
+def make_transformer(old_transformer, model_256):
+ args = dict(old_transformer.config)
+ force_down_up_sample = args["force_down_up_sample"]
+
+ signature = inspect.signature(UVit2DModel.__init__)
+
+ args_ = {
+ "downsample": force_down_up_sample,
+ "upsample": force_down_up_sample,
+ "block_out_channels": args["block_out_channels"][0],
+ "sample_size": 16 if model_256 else 32,
+ }
+
+ for s in list(signature.parameters.keys()):
+ if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]:
+ continue
+
+ args_[s] = args[s]
+
+ new_transformer = UVit2DModel(**args_)
+ new_transformer.to(device)
+
+ new_transformer.set_attn_processor(AttnProcessor())
+
+ state_dict = old_transformer.state_dict()
+
+ state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight")
+ state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight")
+
+ for i in range(22):
+ state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attn_layer_norm.weight"
+ )
+ state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight"
+ )
+
+ state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attention.query.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attention.key.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attention.value.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attention.out.weight"
+ )
+
+ state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattn_layer_norm.weight"
+ )
+ state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight"
+ )
+
+ state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattention.query.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattention.key.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattention.value.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattention.out.weight"
+ )
+
+ state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight"
+ )
+ state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight"
+ )
+
+ wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight")
+ wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight")
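+ # Fuse the two separate muse feed-forward projections into the single double-width GEGLU projection used by diffusers.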
+ proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0)
+ state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight
+
+ state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight")
+
+ if force_down_up_sample:
+ state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight")
+ state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight")
+
+ state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight")
+ state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight")
+
+ state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight")
+
+ for i in range(3):
+ state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.norm.norm.weight"
+ )
+ state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.channelwise.0.weight"
+ )
+ state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma"
+ )
+ state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.channelwise.2.beta"
+ )
+ state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.channelwise.4.weight"
+ )
+ state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
+ )
+
+ state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attention.query.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attention.key.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attention.value.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attention.out.weight"
+ )
+
+ state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight"
+ )
+
+ state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.norm.norm.weight"
+ )
+ state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.channelwise.0.weight"
+ )
+ state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma"
+ )
+ state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.channelwise.2.beta"
+ )
+ state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.channelwise.4.weight"
+ )
+ state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
+ )
+
+ state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attention.query.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attention.key.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attention.value.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attention.out.weight"
+ )
+
+ state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight"
+ )
+
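+    # any remaining keys under `down_blocks.0` / `up_blocks.0` that were not explicitly remapped above are
+    # renamed in bulk below by dropping the block index (e.g. `up_blocks.0.foo` -> `up_block.foo`)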
+ for key in list(state_dict.keys()):
+ if key.startswith("up_blocks.0"):
+ key_ = "up_block." + ".".join(key.split(".")[2:])
+ state_dict[key_] = state_dict.pop(key)
+
+ if key.startswith("down_blocks.0"):
+ key_ = "down_block." + ".".join(key.split(".")[2:])
+ state_dict[key_] = state_dict.pop(key)
+
+ new_transformer.load_state_dict(state_dict)
+
+ input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device)
+ encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device)
+ cond_embeds = torch.randn((1, 768), device=old_transformer.device)
+ micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device)
+
+ old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds)
+ old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2)
+
+ new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds)
+
+ # NOTE: these differences are solely due to using the geglu block that has a single linear layer of
+ # double output dimension instead of two different linear layers
+ max_diff = (old_out - new_out).abs().max()
+ total_diff = (old_out - new_out).abs().sum()
+ print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}")
+ assert max_diff < 0.01
+ assert total_diff < 1500
+
+ return new_transformer
+
+
+def make_vqvae(old_vae):
+ new_vae = VQModel(
+ act_fn="silu",
+ block_out_channels=[128, 256, 256, 512, 768],
+ down_block_types=[
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ],
+ in_channels=3,
+ latent_channels=64,
+ layers_per_block=2,
+ norm_num_groups=32,
+ num_vq_embeddings=8192,
+ out_channels=3,
+ sample_size=32,
+ up_block_types=[
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ],
+ mid_block_add_attention=False,
+ lookup_from_codebook=True,
+ )
+ new_vae.to(device)
+
+ # fmt: off
+
+ new_state_dict = {}
+
+ old_state_dict = old_vae.state_dict()
+
+ new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight")
+ new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias")
+
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0")
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1")
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2")
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3")
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4")
+
+ new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight")
+ new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias")
+ new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight")
+ new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias")
+ new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight")
+ new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias")
+ new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight")
+ new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias")
+ new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight")
+ new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias")
+ new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight")
+ new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias")
+ new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight")
+ new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias")
+ new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight")
+ new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias")
+ new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight")
+ new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias")
+ new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight")
+ new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias")
+ new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight")
+ new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias")
+ new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight")
+ new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight")
+ new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias")
+ new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight")
+ new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias")
+ new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight")
+ new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias")
+ new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight")
+ new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias")
+ new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight")
+ new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias")
+ new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight")
+ new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias")
+ new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight")
+ new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias")
+ new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight")
+ new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias")
+ new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight")
+ new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias")
+ new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight")
+ new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias")
+
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4")
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3")
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2")
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1")
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0")
+
+ new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight")
+ new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias")
+ new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight")
+ new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias")
+
+ # fmt: on
+
+ assert len(old_state_dict.keys()) == 0
+
+ new_vae.load_state_dict(new_state_dict)
+
+ input = torch.randn((1, 3, 512, 512), device=device)
+ input = input.clamp(-1, 1)
+
+ old_encoder_output = old_vae.quant_conv(old_vae.encoder(input))
+ new_encoder_output = new_vae.quant_conv(new_vae.encoder(input))
+ assert (old_encoder_output == new_encoder_output).all()
+
+ old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output))
+ new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output))
+
+ # assert (old_decoder_output == new_decoder_output).all()
+    print("skipping vae decoder equivalence check")
+ print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}")
+
+ old_output = old_vae(input)[0]
+ new_output = new_vae(input)[0]
+
+ # assert (old_output == new_output).all()
+ print("skipping full vae equivalence check")
+    print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
+
+ return new_vae
+
+
+def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to):
+ # fmt: off
+
+ new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias")
+ new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias")
+ new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias")
+ new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias")
+
+ if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict:
+ new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias")
+
+ new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight")
+ new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias")
+ new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight")
+ new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias")
+ new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight")
+ new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias")
+ new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight")
+ new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias")
+
+ if f"{prefix_from}.downsample.conv.weight" in old_state_dict:
+ new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight")
+ new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias")
+
+ if f"{prefix_from}.upsample.conv.weight" in old_state_dict:
+ new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight")
+ new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias")
+
+ if f"{prefix_from}.block.2.norm1.weight" in old_state_dict:
+ new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight")
+ new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias")
+ new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight")
+ new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias")
+ new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight")
+ new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias")
+ new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight")
+ new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias")
+
+ # fmt: on
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py b/diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..21567ffa9e7a95abd86c573761ca90d1648b8002
--- /dev/null
+++ b/diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py
@@ -0,0 +1,69 @@
+import argparse
+import os
+
+import torch
+from huggingface_hub import create_repo, upload_folder
+from safetensors.torch import load_file, save_file
+
+
+def convert_motion_module(original_state_dict):
+ converted_state_dict = {}
+    for k, v in original_state_dict.items():
+        # positional encoder weights are intentionally skipped; they are not part of the converted module
+        if "pos_encoder" in k:
+            continue
+
+        converted_state_dict[
+            k.replace(".norms.0", ".norm1")
+            .replace(".norms.1", ".norm2")
+            .replace(".ff_norm", ".norm3")
+            .replace(".attention_blocks.0", ".attn1")
+            .replace(".attention_blocks.1", ".attn2")
+            .replace(".temporal_transformer", "")
+        ] = v
+
+ return converted_state_dict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
+ parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ default=False,
+        help="Whether to push the converted model to the Hugging Face Hub or not",
+ )
+
+ return parser.parse_args()
+
+
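+# Example invocation (hypothetical paths; the flags mirror the argparse definitions above, and
+# --push_to_hub can be appended to upload the converted folder):
+#   python convert_animatediff_motion_lora_to_diffusers.py \
+#       --ckpt_path /path/to/motion_lora.safetensors \
+#       --output_path ./converted_motion_lora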
+if __name__ == "__main__":
+ args = get_args()
+
+ if args.ckpt_path.endswith(".safetensors"):
+ state_dict = load_file(args.ckpt_path)
+ else:
+ state_dict = torch.load(args.ckpt_path, map_location="cpu")
+
+ if "state_dict" in state_dict.keys():
+ state_dict = state_dict["state_dict"]
+
+ conv_state_dict = convert_motion_module(state_dict)
+
+ # convert to new format
+ output_dict = {}
+ for module_name, params in conv_state_dict.items():
+ if type(params) is not torch.Tensor:
+ continue
+ output_dict.update({f"unet.{module_name}": params})
+
+ os.makedirs(args.output_path, exist_ok=True)
+
+ filepath = os.path.join(args.output_path, "diffusion_pytorch_model.safetensors")
+ save_file(output_dict, filepath)
+
+ if args.push_to_hub:
+ repo_id = create_repo(args.output_path, exist_ok=True).repo_id
+ upload_folder(repo_id=repo_id, folder_path=args.output_path, repo_type="model")
diff --git a/diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py b/diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e188a6a533e86a837b58b61ba68905d5ab24a008
--- /dev/null
+++ b/diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py
@@ -0,0 +1,62 @@
+import argparse
+
+import torch
+from safetensors.torch import load_file
+
+from diffusers import MotionAdapter
+
+
+def convert_motion_module(original_state_dict):
+ converted_state_dict = {}
+    for k, v in original_state_dict.items():
+        # positional encoder weights are intentionally skipped; they are loaded separately with strict=False below
+        if "pos_encoder" in k:
+            continue
+
+        converted_state_dict[
+            k.replace(".norms.0", ".norm1")
+            .replace(".norms.1", ".norm2")
+            .replace(".ff_norm", ".norm3")
+            .replace(".attention_blocks.0", ".attn1")
+            .replace(".attention_blocks.1", ".attn2")
+            .replace(".temporal_transformer", "")
+        ] = v
+
+ return converted_state_dict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt_path", type=str, required=True)
+ parser.add_argument("--output_path", type=str, required=True)
+ parser.add_argument("--use_motion_mid_block", action="store_true")
+ parser.add_argument("--motion_max_seq_length", type=int, default=32)
+ parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
+ parser.add_argument("--save_fp16", action="store_true")
+
+ return parser.parse_args()
+
+
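+# Example invocation (hypothetical paths; --use_motion_mid_block and --motion_max_seq_length depend on the
+# checkpoint being converted, and --save_fp16 additionally writes an fp16 variant):
+#   python convert_animatediff_motion_module_to_diffusers.py \
+#       --ckpt_path /path/to/motion_module.ckpt \
+#       --output_path ./motion_adapter \
+#       --motion_max_seq_length 32 --save_fp16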
+if __name__ == "__main__":
+ args = get_args()
+
+ if args.ckpt_path.endswith(".safetensors"):
+ state_dict = load_file(args.ckpt_path)
+ else:
+ state_dict = torch.load(args.ckpt_path, map_location="cpu")
+
+ if "state_dict" in state_dict.keys():
+ state_dict = state_dict["state_dict"]
+
+ conv_state_dict = convert_motion_module(state_dict)
+ adapter = MotionAdapter(
+ block_out_channels=args.block_out_channels,
+ use_motion_mid_block=args.use_motion_mid_block,
+ motion_max_seq_length=args.motion_max_seq_length,
+ )
+ # skip loading position embeddings
+ adapter.load_state_dict(conv_state_dict, strict=False)
+ adapter.save_pretrained(args.output_path)
+
+ if args.save_fp16:
+ adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")
diff --git a/diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py b/diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f246dceb97f841273d0904248e4bfd8c341fda53
--- /dev/null
+++ b/diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py
@@ -0,0 +1,83 @@
+import argparse
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from diffusers import SparseControlNetModel
+
+
+KEYS_RENAME_MAPPING = {
+ ".attention_blocks.0": ".attn1",
+ ".attention_blocks.1": ".attn2",
+ ".attn1.pos_encoder": ".pos_embed",
+ ".ff_norm": ".norm3",
+ ".norms.0": ".norm1",
+ ".norms.1": ".norm2",
+ ".temporal_transformer": "",
+}
+
+
+def convert(original_state_dict: Dict[str, nn.Module]) -> Dict[str, nn.Module]:
+ converted_state_dict = {}
+
+ for key in list(original_state_dict.keys()):
+ renamed_key = key
+        for old_name, new_name in KEYS_RENAME_MAPPING.items():
+            renamed_key = renamed_key.replace(old_name, new_name)
+ converted_state_dict[renamed_key] = original_state_dict.pop(key)
+
+ return converted_state_dict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
+ parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
+ parser.add_argument(
+ "--max_motion_seq_length",
+ type=int,
+ default=32,
+ help="Max motion sequence length supported by the motion adapter",
+ )
+ parser.add_argument(
+ "--conditioning_channels", type=int, default=4, help="Number of channels in conditioning input to controlnet"
+ )
+ parser.add_argument(
+ "--use_simplified_condition_embedding",
+ action="store_true",
+ default=False,
+        help="Whether or not to use simplified condition embedding. When `conditioning_channels==4` (i.e. latent inputs), set this to `True`. When `conditioning_channels==3` (i.e. image inputs), set this to `False`.",
+ )
+ parser.add_argument(
+ "--save_fp16",
+ action="store_true",
+ default=False,
+ help="Whether or not to save model in fp16 precision along with fp32",
+ )
+ parser.add_argument(
+ "--push_to_hub", action="store_true", default=False, help="Whether or not to push saved model to the HF hub"
+ )
+ return parser.parse_args()
+
+
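+# Example invocation (hypothetical path; per the help text above, use --conditioning_channels 3 for image
+# conditioning, or the default of 4 together with --use_simplified_condition_embedding for latent conditioning):
+#   python convert_animatediff_sparsectrl_to_diffusers.py \
+#       --ckpt_path /path/to/sparsectrl.ckpt \
+#       --output_path ./sparsectrl_converted \
+#       --use_simplified_condition_embedding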
+if __name__ == "__main__":
+ args = get_args()
+
+ state_dict = torch.load(args.ckpt_path, map_location="cpu")
+ if "state_dict" in state_dict.keys():
+ state_dict: dict = state_dict["state_dict"]
+
+ controlnet = SparseControlNetModel(
+ conditioning_channels=args.conditioning_channels,
+ motion_max_seq_length=args.max_motion_seq_length,
+ use_simplified_condition_embedding=args.use_simplified_condition_embedding,
+ )
+
+ state_dict = convert(state_dict)
+ controlnet.load_state_dict(state_dict, strict=True)
+
+ controlnet.save_pretrained(args.output_path, push_to_hub=args.push_to_hub)
+ if args.save_fp16:
+ controlnet = controlnet.to(dtype=torch.float16)
+ controlnet.save_pretrained(args.output_path, variant="fp16", push_to_hub=args.push_to_hub)
diff --git a/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py b/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffb735e18224a7ef48503367112f5ce8142bdf9c
--- /dev/null
+++ b/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py
@@ -0,0 +1,184 @@
+import argparse
+import time
+from pathlib import Path
+from typing import Any, Dict, Literal
+
+import torch
+
+from diffusers import AsymmetricAutoencoderKL
+
+
+ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ],
+ "down_block_out_channels": [128, 256, 512, 512],
+ "layers_per_down_block": 2,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ],
+ "up_block_out_channels": [192, 384, 768, 768],
+ "layers_per_up_block": 3,
+ "act_fn": "silu",
+ "latent_channels": 4,
+ "norm_num_groups": 32,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+}
+
+ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ],
+ "down_block_out_channels": [128, 256, 512, 512],
+ "layers_per_down_block": 2,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ],
+ "up_block_out_channels": [256, 512, 1024, 1024],
+ "layers_per_up_block": 5,
+ "act_fn": "silu",
+ "latent_channels": 4,
+ "norm_num_groups": 32,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+}
+
+
+def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
+ converted_state_dict = {}
+ for k, v in original_state_dict.items():
+ if k.startswith("encoder."):
+ converted_state_dict[
+ k.replace("encoder.down.", "encoder.down_blocks.")
+ .replace("encoder.mid.", "encoder.mid_block.")
+ .replace("encoder.norm_out.", "encoder.conv_norm_out.")
+ .replace(".downsample.", ".downsamplers.0.")
+ .replace(".nin_shortcut.", ".conv_shortcut.")
+ .replace(".block.", ".resnets.")
+ .replace(".block_1.", ".resnets.0.")
+ .replace(".block_2.", ".resnets.1.")
+ .replace(".attn_1.k.", ".attentions.0.to_k.")
+ .replace(".attn_1.q.", ".attentions.0.to_q.")
+ .replace(".attn_1.v.", ".attentions.0.to_v.")
+ .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
+ .replace(".attn_1.norm.", ".attentions.0.group_norm.")
+ ] = v
+ elif k.startswith("decoder.") and "up_layers" not in k:
+ converted_state_dict[
+ k.replace("decoder.encoder.", "decoder.condition_encoder.")
+ .replace(".norm_out.", ".conv_norm_out.")
+ .replace(".up.0.", ".up_blocks.3.")
+ .replace(".up.1.", ".up_blocks.2.")
+ .replace(".up.2.", ".up_blocks.1.")
+ .replace(".up.3.", ".up_blocks.0.")
+ .replace(".block.", ".resnets.")
+ .replace("mid", "mid_block")
+ .replace(".0.upsample.", ".0.upsamplers.0.")
+ .replace(".1.upsample.", ".1.upsamplers.0.")
+ .replace(".2.upsample.", ".2.upsamplers.0.")
+ .replace(".nin_shortcut.", ".conv_shortcut.")
+ .replace(".block_1.", ".resnets.0.")
+ .replace(".block_2.", ".resnets.1.")
+ .replace(".attn_1.k.", ".attentions.0.to_k.")
+ .replace(".attn_1.q.", ".attentions.0.to_q.")
+ .replace(".attn_1.v.", ".attentions.0.to_v.")
+ .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
+ .replace(".attn_1.norm.", ".attentions.0.group_norm.")
+ ] = v
+ elif k.startswith("quant_conv."):
+ converted_state_dict[k] = v
+ elif k.startswith("post_quant_conv."):
+ converted_state_dict[k] = v
+ else:
+ print(f" skipping key `{k}`")
+ # fix weights shape
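+    # (the mid-block attention projections are presumably 1x1 convolutions in the original checkpoint, while
+    # diffusers uses linear layers, so the trailing spatial dimensions are dropped from those weights below)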
+ for k, v in converted_state_dict.items():
+ if (
+ (k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
+ and k.endswith("weight")
+ and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
+ ):
+ converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]
+
+ return converted_state_dict
+
+
+def get_asymmetric_autoencoder_kl_from_original_checkpoint(
+ scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
+) -> AsymmetricAutoencoderKL:
+ print("Loading original state_dict")
+ original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
+ original_state_dict = original_state_dict["state_dict"]
+ print("Converting state_dict")
+ converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
+ kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
+ print("Initializing AsymmetricAutoencoderKL model")
+ asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
+ print("Loading weight from converted state_dict")
+ asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
+ asymmetric_autoencoder_kl.eval()
+ print("AsymmetricAutoencoderKL successfully initialized")
+ return asymmetric_autoencoder_kl
+
+
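+# Example invocation (hypothetical checkpoint path; --scale must be "1.5" or "2" and selects the matching
+# config above, while --map_location defaults to "cpu"):
+#   python convert_asymmetric_vqgan_to_diffusers.py \
+#       --scale 1.5 \
+#       --original_checkpoint_path /path/to/asymmetric_vqgan_x1.5.ckpt \
+#       --output_path ./asymmetric-autoencoder-kl-x-1-5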
+if __name__ == "__main__":
+ start = time.time()
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--scale",
+ default=None,
+ type=str,
+ required=True,
+ help="Asymmetric VQGAN scale: `1.5` or `2`",
+ )
+ parser.add_argument(
+ "--original_checkpoint_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to the original Asymmetric VQGAN checkpoint",
+ )
+ parser.add_argument(
+ "--output_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to save pretrained AsymmetricAutoencoderKL model",
+ )
+ parser.add_argument(
+ "--map_location",
+ default="cpu",
+ type=str,
+ required=False,
+ help="The device passed to `map_location` when loading the checkpoint",
+ )
+ args = parser.parse_args()
+
+    assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` or `2`"
+ assert Path(args.original_checkpoint_path).is_file()
+
+ asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
+ scale=args.scale,
+ original_checkpoint_path=args.original_checkpoint_path,
+ map_location=torch.device(args.map_location),
+ )
+ print("Saving pretrained AsymmetricAutoencoderKL")
+ asymmetric_autoencoder_kl.save_pretrained(args.output_path)
+ print(f"Done in {time.time() - start:.2f} seconds")
diff --git a/diffusers/scripts/convert_aura_flow_to_diffusers.py b/diffusers/scripts/convert_aura_flow_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..74c34f4851ff203422275b237cb03423a17d54a1
--- /dev/null
+++ b/diffusers/scripts/convert_aura_flow_to_diffusers.py
@@ -0,0 +1,131 @@
+import argparse
+
+import torch
+from huggingface_hub import hf_hub_download
+
+from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel
+
+
+def load_original_state_dict(args):
+ model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin")
+ state_dict = torch.load(model_pt, map_location="cpu")
+ return state_dict
+
+
+def calculate_layers(state_dict_keys, key_prefix):
+ dit_layers = set()
+ for k in state_dict_keys:
+ if key_prefix in k:
+ dit_layers.add(int(k.split(".")[2]))
+ print(f"{key_prefix}: {len(dit_layers)}")
+ return len(dit_layers)
+
+
+# similar to SD3 but only for the last norm layer
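+# (the original checkpoint appears to store the final modulation weight as [shift, scale]; it is re-ordered to
+# [scale, shift] here, and the `dim` argument is unused, presumably kept for parity with the SD3 helper this
+# is modeled on)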
+def swap_scale_shift(weight, dim):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def convert_transformer(state_dict):
+ converted_state_dict = {}
+ state_dict_keys = list(state_dict.keys())
+
+ converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
+ converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
+ converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
+ converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")
+
+ converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
+ converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
+ converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
+ converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")
+
+ converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")
+
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+ # MMDiT blocks 🎸.
+ for i in range(mmdit_layers):
+ # feed-forward
+ path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+ weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for orig_k, diffuser_k in path_mapping.items():
+ for k, v in weight_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop(
+ f"model.double_layers.{i}.{orig_k}.{k}.weight"
+ )
+
+ # norms
+ path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+ for orig_k, diffuser_k in path_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
+ f"model.double_layers.{i}.{orig_k}.1.weight"
+ )
+
+ # attns
+ x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+ context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+ for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+ for k, v in attn_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
+ f"model.double_layers.{i}.attn.{k}.weight"
+ )
+
+ # Single-DiT blocks.
+ for i in range(single_dit_layers):
+ # feed-forward
+ mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for k, v in mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop(
+ f"model.single_layers.{i}.mlp.{k}.weight"
+ )
+
+ # norms
+ converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
+ f"model.single_layers.{i}.modCX.1.weight"
+ )
+
+ # attns
+ x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+ for k, v in x_attn_mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
+ f"model.single_layers.{i}.attn.{k}.weight"
+ )
+
+ # Final blocks.
+ converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)
+
+ return converted_state_dict
+
+
+@torch.no_grad()
+def populate_state_dict(args):
+ original_state_dict = load_original_state_dict(args)
+ state_dict_keys = list(original_state_dict.keys())
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+ converted_state_dict = convert_transformer(original_state_dict)
+ model_diffusers = AuraFlowTransformer2DModel(
+ num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
+ )
+ model_diffusers.load_state_dict(converted_state_dict, strict=True)
+
+ return model_diffusers
+
+
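+# Example invocation (the repo id and dump path default to the argparse values below; --hub_id is only
+# needed when pushing the converted model):
+#   python convert_aura_flow_to_diffusers.py --dump_path ./aura-flow-converted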
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
+ parser.add_argument("--dump_path", default="aura-flow", type=str)
+ parser.add_argument("--hub_id", default=None, type=str)
+ args = parser.parse_args()
+
+ model_diffusers = populate_state_dict(args)
+ model_diffusers.save_pretrained(args.dump_path)
+ if args.hub_id is not None:
+ model_diffusers.push_to_hub(args.hub_id)
diff --git a/diffusers/scripts/convert_blipdiffusion_to_diffusers.py b/diffusers/scripts/convert_blipdiffusion_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..03cf67e5476bd43260d2829767fffc220a7801c1
--- /dev/null
+++ b/diffusers/scripts/convert_blipdiffusion_to_diffusers.py
@@ -0,0 +1,343 @@
+"""
+This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
+"""
+
+import argparse
+import os
+import tempfile
+
+import torch
+from lavis.models import load_model_and_preprocess
+from transformers import CLIPTokenizer
+from transformers.models.blip_2.configuration_blip_2 import Blip2Config
+
+from diffusers import (
+ AutoencoderKL,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines import BlipDiffusionPipeline
+from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
+from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
+from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
+
+
+BLIP2_CONFIG = {
+ "vision_config": {
+ "hidden_size": 1024,
+ "num_hidden_layers": 23,
+ "num_attention_heads": 16,
+ "image_size": 224,
+ "patch_size": 14,
+ "intermediate_size": 4096,
+ "hidden_act": "quick_gelu",
+ },
+ "qformer_config": {
+ "cross_attention_frequency": 1,
+ "encoder_hidden_size": 1024,
+ "vocab_size": 30523,
+ },
+ "num_query_tokens": 16,
+}
+blip2config = Blip2Config(**BLIP2_CONFIG)
+
+
+def qformer_model_from_original_config():
+ qformer = Blip2QFormerModel(blip2config)
+ return qformer
+
+
+def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
+ embeddings = {}
+ embeddings.update(
+ {
+ f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
+ f"{original_embeddings_prefix}.word_embeddings.weight"
+ ]
+ }
+ )
+ embeddings.update(
+ {
+ f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
+ f"{original_embeddings_prefix}.position_embeddings.weight"
+ ]
+ }
+ )
+ embeddings.update(
+ {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
+ )
+ embeddings.update(
+ {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
+ )
+ return embeddings
+
+
+def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
+ proj_layer = {}
+ proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
+ return proj_layer
+
+
+def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
+ attention = {}
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.attention.query.weight": model[
+ f"{original_attention_prefix}.self.query.weight"
+ ]
+ }
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
+ )
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.attention.value.weight": model[
+ f"{original_attention_prefix}.self.value.weight"
+ ]
+ }
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
+ )
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
+ f"{original_attention_prefix}.output.LayerNorm.weight"
+ ]
+ }
+ )
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
+ f"{original_attention_prefix}.output.LayerNorm.bias"
+ ]
+ }
+ )
+ return attention
+
+
+def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
+ output_layers = {}
+ output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
+ output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
+ output_layers.update(
+ {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
+ )
+ output_layers.update(
+ {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
+ )
+ return output_layers
+
+
+def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
+ encoder = {}
+ for i in range(blip2config.qformer_config.num_hidden_layers):
+ encoder.update(
+ attention_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
+ )
+ )
+ encoder.update(
+ attention_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
+ )
+ )
+
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
+ f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
+ ]
+ }
+ )
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
+ f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
+ ]
+ }
+ )
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
+ f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
+ ]
+ }
+ )
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
+ f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
+ ]
+ }
+ )
+
+ encoder.update(
+ output_layers_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
+ )
+ )
+ encoder.update(
+ output_layers_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
+ )
+ )
+ return encoder
+
+
+def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
+ visual_encoder_layer = {}
+
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
+ )
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
+ )
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
+ )
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
+ )
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
+
+ return visual_encoder_layer
+
+
+def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
+ visual_encoder = {}
+
+ visual_encoder.update(
+ {
+ f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
+ .unsqueeze(0)
+ .unsqueeze(0)
+ }
+ )
+ visual_encoder.update(
+ {
+ f"{diffuser_prefix}.embeddings.position_embedding": model[
+ f"{original_prefix}.positional_embedding"
+ ].unsqueeze(0)
+ }
+ )
+ visual_encoder.update(
+ {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
+ )
+ visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
+ visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
+
+ for i in range(blip2config.vision_config.num_hidden_layers):
+ visual_encoder.update(
+ visual_encoder_layer_from_original_checkpoint(
+ model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
+ )
+ )
+
+ visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
+ visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
+
+ return visual_encoder
+
+
+def qformer_original_checkpoint_to_diffusers_checkpoint(model):
+ qformer_checkpoint = {}
+ qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
+ qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
+ qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
+ qformer_checkpoint.update(
+ encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
+ )
+ qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
+ return qformer_checkpoint
+
+
+def get_qformer(model):
+ print("loading qformer")
+
+ qformer = qformer_model_from_original_config()
+ qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
+
+ load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
+
+ print("done loading qformer")
+ return qformer
+
+
+def load_checkpoint_to_model(checkpoint, model):
+ with tempfile.NamedTemporaryFile(delete=False) as file:
+ torch.save(checkpoint, file.name)
+ del checkpoint
+ model.load_state_dict(torch.load(file.name), strict=False)
+
+ os.remove(file.name)
+
+
+def save_blip_diffusion_model(model, args):
+ qformer = get_qformer(model)
+ qformer.eval()
+
+ text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ vae.eval()
+ text_encoder.eval()
+ scheduler = PNDMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ set_alpha_to_one=False,
+ skip_prk_steps=True,
+ )
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
+ image_processor = BlipImageProcessor()
+ blip_diffusion = BlipDiffusionPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ scheduler=scheduler,
+ qformer=qformer,
+ image_processor=image_processor,
+ )
+ blip_diffusion.save_pretrained(args.checkpoint_path)
+
+
+def main(args):
+ model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
+ save_blip_diffusion_model(model.state_dict(), args)
+
+
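+# Example invocation (hypothetical output path; requires `LAVIS` built from source, see the module docstring):
+#   python convert_blipdiffusion_to_diffusers.py --checkpoint_path ./blip-diffusion-converted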
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/diffusers/scripts/convert_cogvideox_to_diffusers.py b/diffusers/scripts/convert_cogvideox_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..4343eaf34038dc67faf905fcba5c57917ece208b
--- /dev/null
+++ b/diffusers/scripts/convert_cogvideox_to_diffusers.py
@@ -0,0 +1,290 @@
+import argparse
+from typing import Any, Dict
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers import (
+ AutoencoderKLCogVideoX,
+ CogVideoXDDIMScheduler,
+ CogVideoXImageToVideoPipeline,
+ CogVideoXPipeline,
+ CogVideoXTransformer3DModel,
+)
+
+
+def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
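+    # the original checkpoint fuses Q, K and V into a single `query_key_value` tensor along dim 0; split it
+    # into the separate to_q / to_k / to_v projections expected by diffusers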
+ to_q_key = key.replace("query_key_value", "to_q")
+ to_k_key = key.replace("query_key_value", "to_k")
+ to_v_key = key.replace("query_key_value", "to_v")
+ to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
+ state_dict[to_q_key] = to_q
+ state_dict[to_k_key] = to_k
+ state_dict[to_v_key] = to_v
+ state_dict.pop(key)
+
+
+def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
+ layer_id, weight_or_bias = key.split(".")[-2:]
+
+ if "query" in key:
+ new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
+ elif "key" in key:
+ new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
+
+ state_dict[new_key] = state_dict.pop(key)
+
+
+def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
+ layer_id, _, weight_or_bias = key.split(".")[-3:]
+
+ weights_or_biases = state_dict[key].chunk(12, dim=0)
+ norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
+ norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
+
+ norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
+ state_dict[norm1_key] = norm1_weights_or_biases
+
+ norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
+ state_dict[norm2_key] = norm2_weights_or_biases
+
+ state_dict.pop(key)
+
+
+def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
+ state_dict.pop(key)
+
+
+def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
+ key_split = key.split(".")
+ layer_index = int(key_split[2])
+ replace_layer_index = 4 - 1 - layer_index
+
+ key_split[1] = "up_blocks"
+ key_split[2] = str(replace_layer_index)
+ new_key = ".".join(key_split)
+
+ state_dict[new_key] = state_dict.pop(key)
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "transformer.final_layernorm": "norm_final",
+ "transformer": "transformer_blocks",
+ "attention": "attn1",
+ "mlp": "ff.net",
+ "dense_h_to_4h": "0.proj",
+ "dense_4h_to_h": "2",
+ ".layers": "",
+ "dense": "to_out.0",
+ "input_layernorm": "norm1.norm",
+ "post_attn1_layernorm": "norm2.norm",
+ "time_embed.0": "time_embedding.linear_1",
+ "time_embed.2": "time_embedding.linear_2",
+ "mixins.patch_embed": "patch_embed",
+ "mixins.final_layer.norm_final": "norm_out.norm",
+ "mixins.final_layer.linear": "proj_out",
+ "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+ "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "query_key_value": reassign_query_key_value_inplace,
+ "query_layernorm_list": reassign_query_key_layernorm_inplace,
+ "key_layernorm_list": reassign_query_key_layernorm_inplace,
+ "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
+ "embed_tokens": remove_keys_inplace,
+ "freqs_sin": remove_keys_inplace,
+ "freqs_cos": remove_keys_inplace,
+ "position_embedding": remove_keys_inplace,
+}
+
+VAE_KEYS_RENAME_DICT = {
+ "block.": "resnets.",
+ "down.": "down_blocks.",
+ "downsample": "downsamplers.0",
+ "upsample": "upsamplers.0",
+ "nin_shortcut": "conv_shortcut",
+ "encoder.mid.block_1": "encoder.mid_block.resnets.0",
+ "encoder.mid.block_2": "encoder.mid_block.resnets.1",
+ "decoder.mid.block_1": "decoder.mid_block.resnets.0",
+ "decoder.mid.block_2": "decoder.mid_block.resnets.1",
+}
+
+VAE_SPECIAL_KEYS_REMAP = {
+ "loss": remove_keys_inplace,
+ "up.": replace_up_keys_inplace,
+}
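+# Conversion is done in two passes: keys are first bulk-renamed via the substring replacements in
+# *_KEYS_RENAME_DICT, and any key still matching an entry in *_SPECIAL_KEYS_REMAP is then post-processed
+# (split, reshuffled, or dropped) by the mapped handler above.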
+
+TOKENIZER_MAX_LENGTH = 226
+
+
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+ state_dict = saved_dict
+ if "model" in saved_dict.keys():
+ state_dict = state_dict["model"]
+ if "module" in saved_dict.keys():
+ state_dict = state_dict["module"]
+ if "state_dict" in saved_dict.keys():
+ state_dict = state_dict["state_dict"]
+ return state_dict
+
+
+def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def convert_transformer(
+ ckpt_path: str,
+ num_layers: int,
+ num_attention_heads: int,
+ use_rotary_positional_embeddings: bool,
+ i2v: bool,
+ dtype: torch.dtype,
+):
+ PREFIX_KEY = "model.diffusion_model."
+
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+ transformer = CogVideoXTransformer3DModel(
+ in_channels=32 if i2v else 16,
+ num_layers=num_layers,
+ num_attention_heads=num_attention_heads,
+ use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=i2v,
+ ).to(dtype=dtype)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[len(PREFIX_KEY) :]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+ transformer.load_state_dict(original_state_dict, strict=True)
+ return transformer
+
+
+def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+ vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ vae.load_state_dict(original_state_dict, strict=True)
+ return vae
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+ )
+ parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+ parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+ parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
+ parser.add_argument(
+ "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
+ )
+ parser.add_argument(
+ "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+ )
+ # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
+ parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+ # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
+ parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+ # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
+ parser.add_argument(
+ "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+ )
+ # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
+ parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+ # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
+    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="SNR shift scale for the scheduler")
+    parser.add_argument("--i2v", action="store_true", default=False, help="Whether the checkpoint being converted is an image-to-video (I2V) model")
+ return parser.parse_args()
+
+
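+# Example invocations (hypothetical checkpoint paths; the 2B/5B values follow the comments on the flags above,
+# and the precision flags match the fp16 (2B) / bf16 (5B) defaults the models are meant to run in):
+#   CogVideoX-2B:
+#     python convert_cogvideox_to_diffusers.py --transformer_ckpt_path /path/to/2b_transformer.pt \
+#         --vae_ckpt_path /path/to/2b_vae.pt --output_path ./cogvideox-2b --fp16
+#   CogVideoX-5B:
+#     python convert_cogvideox_to_diffusers.py --transformer_ckpt_path /path/to/5b_transformer.pt \
+#         --vae_ckpt_path /path/to/5b_vae.pt --output_path ./cogvideox-5b --bf16 \
+#         --num_layers 42 --num_attention_heads 48 --use_rotary_positional_embeddings \
+#         --scaling_factor 0.7 --snr_shift_scale 1.0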
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ vae = None
+
+ if args.fp16 and args.bf16:
+ raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
+
+ dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+
+ if args.transformer_ckpt_path is not None:
+ transformer = convert_transformer(
+ args.transformer_ckpt_path,
+ args.num_layers,
+ args.num_attention_heads,
+ args.use_rotary_positional_embeddings,
+ args.i2v,
+ dtype,
+ )
+ if args.vae_ckpt_path is not None:
+ vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+
+ text_encoder_id = "google/t5-v1_1-xxl"
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    # For unclear reasons, the conversion fails unless the text encoder parameters are made contiguous first.
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
+ scheduler = CogVideoXDDIMScheduler.from_config(
+ {
+ "snr_shift_scale": args.snr_shift_scale,
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": False,
+ "num_train_timesteps": 1000,
+ "prediction_type": "v_prediction",
+ "rescale_betas_zero_snr": True,
+ "set_alpha_to_one": True,
+ "timestep_spacing": "trailing",
+ }
+ )
+ if args.i2v:
+ pipeline_cls = CogVideoXImageToVideoPipeline
+ else:
+ pipeline_cls = CogVideoXPipeline
+
+ pipe = pipeline_cls(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ if args.fp16:
+ pipe = pipe.to(dtype=torch.float16)
+ if args.bf16:
+ pipe = pipe.to(dtype=torch.bfloat16)
+
+    # We don't use a variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be odd
+    # for users to have to specify a variant when the default is not fp32 and they simply want to run with the
+    # correct default (fp16 or bf16 here).
+
+    # Sharding the checkpoint (max_shard_size below) is necessary for users with insufficient memory, such as
+    # those using Colab and notebooks, as it reduces the memory needed to load the model.
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
diff --git a/diffusers/scripts/convert_consistency_decoder.py b/diffusers/scripts/convert_consistency_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cb5fc50dd60b2956ba9dd3c4044c49349387943
--- /dev/null
+++ b/diffusers/scripts/convert_consistency_decoder.py
@@ -0,0 +1,1128 @@
+import math
+import os
+import urllib
+import warnings
+from argparse import ArgumentParser
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub.utils import insecure_hashlib
+from safetensors.torch import load_file as stl
+from tqdm import tqdm
+
+from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
+from diffusers.models.autoencoders.vae import Encoder
+from diffusers.models.embeddings import TimestepEmbedding
+from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
+
+
+args = ArgumentParser()
+args.add_argument("--save_pretrained", required=False, default=None, type=str)
+args.add_argument("--test_image", required=True, type=str)
+args = args.parse_args()
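+# Example usage (hypothetical paths; the script also expects `consistency_decoder.safetensors` and
+# `embedding.safetensors` in the working directory, see the state dict loading below):
+#   python convert_consistency_decoder.py --test_image ./test.png --save_pretrained ./consistency-decoder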
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895
+ res = arr[timesteps].float()
+ dims_to_append = len(broadcast_shape) - len(res.shape)
+ return res[(...,) + (None,) * dims_to_append]
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas)
+
+
+def _download(url: str, root: str):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(
+ total=int(source.info().get("Content-Length")),
+ ncols=80,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+class ConsistencyDecoder:
+ def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
+ self.n_distilled_steps = 64
+ download_target = _download(
+ "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt",
+ download_root,
+ )
+ self.ckpt = torch.jit.load(download_target).to(device)
+ self.device = device
+ sigma_data = 0.5
+ betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device)
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+ sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+ sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
+ self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
+ self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
+ self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
+
+ @staticmethod
+ def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
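+ # Snap arbitrary timesteps onto the coarse grid of n_distilled_steps steps the distilled decoder was trained on.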
+ with torch.no_grad():
+ space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
+ rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
+ if truncate_start:
+ rounded_timesteps[rounded_timesteps == total_timesteps] -= space
+ else:
+ rounded_timesteps[rounded_timesteps == total_timesteps] -= space
+ rounded_timesteps[rounded_timesteps == 0] += space
+ return rounded_timesteps
+
+ @staticmethod
+ def ldm_transform_latent(z, extra_scale_factor=1):
+ channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
+ channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]
+
+ if len(z.shape) != 4:
+ raise ValueError()
+
+ z = z * 0.18215
+ channels = [z[:, i] for i in range(z.shape[1])]
+
+ channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)]
+ return torch.stack(channels, dim=1)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ features: torch.Tensor,
+ schedule=[1.0, 0.5],
+ generator=None,
+ ):
+ features = self.ldm_transform_latent(features)
+ ts = self.round_timesteps(
+ torch.arange(0, 1024),
+ 1024,
+ self.n_distilled_steps,
+ truncate_start=False,
+ )
+ shape = (
+ features.size(0),
+ 3,
+ 8 * features.size(2),
+ 8 * features.size(3),
+ )
+ x_start = torch.zeros(shape, device=features.device, dtype=features.dtype)
+ schedule_timesteps = [int((1024 - 1) * s) for s in schedule]
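+ # Consistency sampling: at each scheduled timestep, re-noise the current estimate, let the model predict
+ # x0 directly through the c_skip / c_out parameterization, and carry that prediction into the next step.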
+ for i in schedule_timesteps:
+ t = ts[i].item()
+ t_ = torch.tensor([t] * features.shape[0]).to(self.device)
+ # noise = torch.randn_like(x_start)
+ noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device)
+ x_start = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise
+ )
+ c_in = _extract_into_tensor(self.c_in, t_, x_start.shape)
+
+ if isinstance(self.ckpt, UNet2DModel):
+ input = torch.concat([c_in * x_start, F.upsample_nearest(features, scale_factor=8)], dim=1)
+ model_output = self.ckpt(input, t_).sample
+ else:
+ model_output = self.ckpt(c_in * x_start, t_, features=features)
+
+ B, C = x_start.shape[:2]
+ model_output, _ = torch.split(model_output, C, dim=1)
+ pred_xstart = (
+ _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output
+ + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start
+ ).clamp(-1, 1)
+ x_start = pred_xstart
+ return x_start
+
+
+def save_image(image, name):
+ import numpy as np
+ from PIL import Image
+
+ image = image[0].cpu().numpy()
+ image = (image + 1.0) * 127.5
+ image = image.clip(0, 255).astype(np.uint8)
+ image = Image.fromarray(image.transpose(1, 2, 0))
+ image.save(name)
+
+
+def load_image(uri, size=None, center_crop=False):
+ import numpy as np
+ from PIL import Image
+
+ image = Image.open(uri)
+ if center_crop:
+ image = image.crop(
+ (
+ (image.width - min(image.width, image.height)) // 2,
+ (image.height - min(image.width, image.height)) // 2,
+ (image.width + min(image.width, image.height)) // 2,
+ (image.height + min(image.width, image.height)) // 2,
+ )
+ )
+ if size is not None:
+ image = image.resize(size)
+ image = torch.tensor(np.array(image).transpose(2, 0, 1)).unsqueeze(0).float()
+ image = image / 127.5 - 1.0
+ return image
+
+
+class TimestepEmbedding_(nn.Module):
+ def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
+ super().__init__()
+ self.emb = nn.Embedding(n_time, n_emb)
+ self.f_1 = nn.Linear(n_emb, n_out)
+ self.f_2 = nn.Linear(n_out, n_out)
+
+ def forward(self, x) -> torch.Tensor:
+ x = self.emb(x)
+ x = self.f_1(x)
+ x = F.silu(x)
+ return self.f_2(x)
+
+
+class ImageEmbedding(nn.Module):
+ def __init__(self, in_channels=7, out_channels=320) -> None:
+ super().__init__()
+ self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.f(x)
+
+
+class ImageUnembedding(nn.Module):
+ def __init__(self, in_channels=320, out_channels=6) -> None:
+ super().__init__()
+ self.gn = nn.GroupNorm(32, in_channels)
+ self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.f(F.silu(self.gn(x)))
+
+
+class ConvResblock(nn.Module):
+ def __init__(self, in_features=320, out_features=320) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, out_features * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_features)
+ self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
+
+ self.gn_2 = nn.GroupNorm(32, out_features)
+ self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
+
+ skip_conv = in_features != out_features
+ self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) if skip_conv else nn.Identity()
+
+ def forward(self, x, t):
+ x_skip = x
+ t = self.f_t(F.silu(t))
+ t = t.chunk(2, dim=1)
+ t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
+ t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ f_1 = self.f_1(gn_1)
+
+ gn_2 = self.gn_2(f_1)
+
+ return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
+
+
+# Also a ConvResblock, but with a fixed 2x average-pool downsample applied to both the main and skip paths
+class Downsample(nn.Module):
+ def __init__(self, in_channels=320) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, in_channels * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_channels)
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+ self.gn_2 = nn.GroupNorm(32, in_channels)
+
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, x, t) -> torch.Tensor:
+ x_skip = x
+
+ t = self.f_t(F.silu(t))
+ t_1, t_2 = t.chunk(2, dim=1)
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
+
+ f_1 = self.f_1(avg_pool2d)
+ gn_2 = self.gn_2(f_1)
+
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+ return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)
+
+
+# Also a ConvResblock, but with a fixed 2x nearest-neighbor upsample applied to both the main and skip paths
+class Upsample(nn.Module):
+ def __init__(self, in_channels=1024) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, in_channels * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_channels)
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+ self.gn_2 = nn.GroupNorm(32, in_channels)
+
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, x, t) -> torch.Tensor:
+ x_skip = x
+
+ t = self.f_t(F.silu(t))
+ t_1, t_2 = t.chunk(2, dim=1)
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ upsample = F.upsample_nearest(gn_1, scale_factor=2)
+ f_1 = self.f_1(upsample)
+ gn_2 = self.gn_2(f_1)
+
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+ return f_2 + F.upsample_nearest(x_skip, scale_factor=2)
+
+
+class ConvUNetVAE(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.embed_image = ImageEmbedding()
+ self.embed_time = TimestepEmbedding_()
+
+ down_0 = nn.ModuleList(
+ [
+ ConvResblock(320, 320),
+ ConvResblock(320, 320),
+ ConvResblock(320, 320),
+ Downsample(320),
+ ]
+ )
+ down_1 = nn.ModuleList(
+ [
+ ConvResblock(320, 640),
+ ConvResblock(640, 640),
+ ConvResblock(640, 640),
+ Downsample(640),
+ ]
+ )
+ down_2 = nn.ModuleList(
+ [
+ ConvResblock(640, 1024),
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ Downsample(1024),
+ ]
+ )
+ down_3 = nn.ModuleList(
+ [
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ]
+ )
+ self.down = nn.ModuleList(
+ [
+ down_0,
+ down_1,
+ down_2,
+ down_3,
+ ]
+ )
+
+ self.mid = nn.ModuleList(
+ [
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ]
+ )
+
+ up_3 = nn.ModuleList(
+ [
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ Upsample(1024),
+ ]
+ )
+ up_2 = nn.ModuleList(
+ [
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 + 640, 1024),
+ Upsample(1024),
+ ]
+ )
+ up_1 = nn.ModuleList(
+ [
+ ConvResblock(1024 + 640, 640),
+ ConvResblock(640 * 2, 640),
+ ConvResblock(640 * 2, 640),
+ ConvResblock(320 + 640, 640),
+ Upsample(640),
+ ]
+ )
+ up_0 = nn.ModuleList(
+ [
+ ConvResblock(320 + 640, 320),
+ ConvResblock(320 * 2, 320),
+ ConvResblock(320 * 2, 320),
+ ConvResblock(320 * 2, 320),
+ ]
+ )
+ self.up = nn.ModuleList(
+ [
+ up_0,
+ up_1,
+ up_2,
+ up_3,
+ ]
+ )
+
+ self.output = ImageUnembedding()
+
+ def forward(self, x, t, features) -> torch.Tensor:
+ converted = hasattr(self, "converted") and self.converted
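+ # Before conversion the hand-written blocks are called one by one; once `converted` is set, the swapped-in
+ # diffusers blocks, which return (down) or consume (up) tuples of skip activations, are used instead.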
+
+ x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1)
+
+ if converted:
+ t = self.time_embedding(self.time_proj(t))
+ else:
+ t = self.embed_time(t)
+
+ x = self.embed_image(x)
+
+ skips = [x]
+ for i, down in enumerate(self.down):
+ if converted and i in [0, 1, 2, 3]:
+ x, skips_ = down(x, t)
+ for skip in skips_:
+ skips.append(skip)
+ else:
+ for block in down:
+ x = block(x, t)
+ skips.append(x)
+ print(x.float().abs().sum())
+
+ if converted:
+ x = self.mid(x, t)
+ else:
+ for i in range(2):
+ x = self.mid[i](x, t)
+ print(x.float().abs().sum())
+
+ for i, up in enumerate(self.up[::-1]):
+ if converted and i in [0, 1, 2, 3]:
+ skip_4 = skips.pop()
+ skip_3 = skips.pop()
+ skip_2 = skips.pop()
+ skip_1 = skips.pop()
+ skips_ = (skip_1, skip_2, skip_3, skip_4)
+ x = up(x, skips_, t)
+ else:
+ for block in up:
+ if isinstance(block, ConvResblock):
+ x = torch.concat([x, skips.pop()], dim=1)
+ x = block(x, t)
+
+ return self.output(x)
+
+
+def rename_state_dict_key(k):
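+ # Map the flat key names in the safetensors checkpoint onto the module names defined by ConvUNetVAE above.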
+ k = k.replace("blocks.", "")
+ for i in range(5):
+ k = k.replace(f"down_{i}_", f"down.{i}.")
+ k = k.replace(f"conv_{i}.", f"{i}.")
+ k = k.replace(f"up_{i}_", f"up.{i}.")
+ k = k.replace(f"mid_{i}", f"mid.{i}")
+ k = k.replace("upsamp.", "4.")
+ k = k.replace("downsamp.", "3.")
+ k = k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias")
+ k = k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias")
+ k = k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias")
+ k = k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias")
+ k = k.replace("f.w", "f.weight").replace("f.b", "f.bias")
+ k = k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias")
+ k = k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias")
+ k = k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias")
+ return k
+
+
+def rename_state_dict(sd, embedding):
+ sd = {rename_state_dict_key(k): v for k, v in sd.items()}
+ sd["embed_time.emb.weight"] = embedding["weight"]
+ return sd
+
+
+# encode with stable diffusion vae
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe.vae.cuda()
+
+# construct original decoder with jitted model
+decoder_consistency = ConsistencyDecoder(device="cuda:0")
+
+# construct UNet code, overwrite the decoder with conv_unet_vae
+model = ConvUNetVAE()
+model.load_state_dict(
+ rename_state_dict(
+ stl("consistency_decoder.safetensors"),
+ stl("embedding.safetensors"),
+ )
+)
+model = model.cuda()
+
+decoder_consistency.ckpt = model
+
+image = load_image(args.test_image, size=(256, 256), center_crop=True)
+latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()
+
+# decode with gan
+sample_gan = pipe.vae.decode(latent).sample.detach()
+save_image(sample_gan, "gan.png")
+
+# decode with conv_unet_vae
+sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_orig, "con_orig.png")
+
+
+########### conversion
+
+print("CONVERSION")
+
+print("DOWN BLOCK ONE")
+
+block_one_sd_orig = model.down[0].state_dict()
+block_one_sd_new = {}
+
+for i in range(3):
+ block_one_sd_new[f"resnets.{i}.norm1.weight"] = block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ block_one_sd_new[f"resnets.{i}.norm1.bias"] = block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ block_one_sd_new[f"resnets.{i}.conv1.weight"] = block_one_sd_orig.pop(f"{i}.f_1.weight")
+ block_one_sd_new[f"resnets.{i}.conv1.bias"] = block_one_sd_orig.pop(f"{i}.f_1.bias")
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_one_sd_orig.pop(f"{i}.f_t.weight")
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_one_sd_orig.pop(f"{i}.f_t.bias")
+ block_one_sd_new[f"resnets.{i}.norm2.weight"] = block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ block_one_sd_new[f"resnets.{i}.norm2.bias"] = block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ block_one_sd_new[f"resnets.{i}.conv2.weight"] = block_one_sd_orig.pop(f"{i}.f_2.weight")
+ block_one_sd_new[f"resnets.{i}.conv2.bias"] = block_one_sd_orig.pop(f"{i}.f_2.bias")
+
+block_one_sd_new["downsamplers.0.norm1.weight"] = block_one_sd_orig.pop("3.gn_1.weight")
+block_one_sd_new["downsamplers.0.norm1.bias"] = block_one_sd_orig.pop("3.gn_1.bias")
+block_one_sd_new["downsamplers.0.conv1.weight"] = block_one_sd_orig.pop("3.f_1.weight")
+block_one_sd_new["downsamplers.0.conv1.bias"] = block_one_sd_orig.pop("3.f_1.bias")
+block_one_sd_new["downsamplers.0.time_emb_proj.weight"] = block_one_sd_orig.pop("3.f_t.weight")
+block_one_sd_new["downsamplers.0.time_emb_proj.bias"] = block_one_sd_orig.pop("3.f_t.bias")
+block_one_sd_new["downsamplers.0.norm2.weight"] = block_one_sd_orig.pop("3.gn_2.weight")
+block_one_sd_new["downsamplers.0.norm2.bias"] = block_one_sd_orig.pop("3.gn_2.bias")
+block_one_sd_new["downsamplers.0.conv2.weight"] = block_one_sd_orig.pop("3.f_2.weight")
+block_one_sd_new["downsamplers.0.conv2.bias"] = block_one_sd_orig.pop("3.f_2.bias")
+
+assert len(block_one_sd_orig) == 0
+
+block_one = ResnetDownsampleBlock2D(
+ in_channels=320,
+ out_channels=320,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_one.load_state_dict(block_one_sd_new)
+
+print("DOWN BLOCK TWO")
+
+block_two_sd_orig = model.down[1].state_dict()
+block_two_sd_new = {}
+
+for i in range(3):
+ block_two_sd_new[f"resnets.{i}.norm1.weight"] = block_two_sd_orig.pop(f"{i}.gn_1.weight")
+ block_two_sd_new[f"resnets.{i}.norm1.bias"] = block_two_sd_orig.pop(f"{i}.gn_1.bias")
+ block_two_sd_new[f"resnets.{i}.conv1.weight"] = block_two_sd_orig.pop(f"{i}.f_1.weight")
+ block_two_sd_new[f"resnets.{i}.conv1.bias"] = block_two_sd_orig.pop(f"{i}.f_1.bias")
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_two_sd_orig.pop(f"{i}.f_t.weight")
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_two_sd_orig.pop(f"{i}.f_t.bias")
+ block_two_sd_new[f"resnets.{i}.norm2.weight"] = block_two_sd_orig.pop(f"{i}.gn_2.weight")
+ block_two_sd_new[f"resnets.{i}.norm2.bias"] = block_two_sd_orig.pop(f"{i}.gn_2.bias")
+ block_two_sd_new[f"resnets.{i}.conv2.weight"] = block_two_sd_orig.pop(f"{i}.f_2.weight")
+ block_two_sd_new[f"resnets.{i}.conv2.bias"] = block_two_sd_orig.pop(f"{i}.f_2.bias")
+
+ if i == 0:
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_two_sd_orig.pop(f"{i}.f_s.weight")
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_two_sd_orig.pop(f"{i}.f_s.bias")
+
+block_two_sd_new["downsamplers.0.norm1.weight"] = block_two_sd_orig.pop("3.gn_1.weight")
+block_two_sd_new["downsamplers.0.norm1.bias"] = block_two_sd_orig.pop("3.gn_1.bias")
+block_two_sd_new["downsamplers.0.conv1.weight"] = block_two_sd_orig.pop("3.f_1.weight")
+block_two_sd_new["downsamplers.0.conv1.bias"] = block_two_sd_orig.pop("3.f_1.bias")
+block_two_sd_new["downsamplers.0.time_emb_proj.weight"] = block_two_sd_orig.pop("3.f_t.weight")
+block_two_sd_new["downsamplers.0.time_emb_proj.bias"] = block_two_sd_orig.pop("3.f_t.bias")
+block_two_sd_new["downsamplers.0.norm2.weight"] = block_two_sd_orig.pop("3.gn_2.weight")
+block_two_sd_new["downsamplers.0.norm2.bias"] = block_two_sd_orig.pop("3.gn_2.bias")
+block_two_sd_new["downsamplers.0.conv2.weight"] = block_two_sd_orig.pop("3.f_2.weight")
+block_two_sd_new["downsamplers.0.conv2.bias"] = block_two_sd_orig.pop("3.f_2.bias")
+
+assert len(block_two_sd_orig) == 0
+
+block_two = ResnetDownsampleBlock2D(
+ in_channels=320,
+ out_channels=640,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_two.load_state_dict(block_two_sd_new)
+
+print("DOWN BLOCK THREE")
+
+block_three_sd_orig = model.down[2].state_dict()
+block_three_sd_new = {}
+
+for i in range(3):
+ block_three_sd_new[f"resnets.{i}.norm1.weight"] = block_three_sd_orig.pop(f"{i}.gn_1.weight")
+ block_three_sd_new[f"resnets.{i}.norm1.bias"] = block_three_sd_orig.pop(f"{i}.gn_1.bias")
+ block_three_sd_new[f"resnets.{i}.conv1.weight"] = block_three_sd_orig.pop(f"{i}.f_1.weight")
+ block_three_sd_new[f"resnets.{i}.conv1.bias"] = block_three_sd_orig.pop(f"{i}.f_1.bias")
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_three_sd_orig.pop(f"{i}.f_t.weight")
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_three_sd_orig.pop(f"{i}.f_t.bias")
+ block_three_sd_new[f"resnets.{i}.norm2.weight"] = block_three_sd_orig.pop(f"{i}.gn_2.weight")
+ block_three_sd_new[f"resnets.{i}.norm2.bias"] = block_three_sd_orig.pop(f"{i}.gn_2.bias")
+ block_three_sd_new[f"resnets.{i}.conv2.weight"] = block_three_sd_orig.pop(f"{i}.f_2.weight")
+ block_three_sd_new[f"resnets.{i}.conv2.bias"] = block_three_sd_orig.pop(f"{i}.f_2.bias")
+
+ if i == 0:
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_three_sd_orig.pop(f"{i}.f_s.weight")
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_three_sd_orig.pop(f"{i}.f_s.bias")
+
+block_three_sd_new["downsamplers.0.norm1.weight"] = block_three_sd_orig.pop("3.gn_1.weight")
+block_three_sd_new["downsamplers.0.norm1.bias"] = block_three_sd_orig.pop("3.gn_1.bias")
+block_three_sd_new["downsamplers.0.conv1.weight"] = block_three_sd_orig.pop("3.f_1.weight")
+block_three_sd_new["downsamplers.0.conv1.bias"] = block_three_sd_orig.pop("3.f_1.bias")
+block_three_sd_new["downsamplers.0.time_emb_proj.weight"] = block_three_sd_orig.pop("3.f_t.weight")
+block_three_sd_new["downsamplers.0.time_emb_proj.bias"] = block_three_sd_orig.pop("3.f_t.bias")
+block_three_sd_new["downsamplers.0.norm2.weight"] = block_three_sd_orig.pop("3.gn_2.weight")
+block_three_sd_new["downsamplers.0.norm2.bias"] = block_three_sd_orig.pop("3.gn_2.bias")
+block_three_sd_new["downsamplers.0.conv2.weight"] = block_three_sd_orig.pop("3.f_2.weight")
+block_three_sd_new["downsamplers.0.conv2.bias"] = block_three_sd_orig.pop("3.f_2.bias")
+
+assert len(block_three_sd_orig) == 0
+
+block_three = ResnetDownsampleBlock2D(
+ in_channels=640,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_three.load_state_dict(block_three_sd_new)
+
+print("DOWN BLOCK FOUR")
+
+block_four_sd_orig = model.down[3].state_dict()
+block_four_sd_new = {}
+
+for i in range(3):
+ block_four_sd_new[f"resnets.{i}.norm1.weight"] = block_four_sd_orig.pop(f"{i}.gn_1.weight")
+ block_four_sd_new[f"resnets.{i}.norm1.bias"] = block_four_sd_orig.pop(f"{i}.gn_1.bias")
+ block_four_sd_new[f"resnets.{i}.conv1.weight"] = block_four_sd_orig.pop(f"{i}.f_1.weight")
+ block_four_sd_new[f"resnets.{i}.conv1.bias"] = block_four_sd_orig.pop(f"{i}.f_1.bias")
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_four_sd_orig.pop(f"{i}.f_t.weight")
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_four_sd_orig.pop(f"{i}.f_t.bias")
+ block_four_sd_new[f"resnets.{i}.norm2.weight"] = block_four_sd_orig.pop(f"{i}.gn_2.weight")
+ block_four_sd_new[f"resnets.{i}.norm2.bias"] = block_four_sd_orig.pop(f"{i}.gn_2.bias")
+ block_four_sd_new[f"resnets.{i}.conv2.weight"] = block_four_sd_orig.pop(f"{i}.f_2.weight")
+ block_four_sd_new[f"resnets.{i}.conv2.bias"] = block_four_sd_orig.pop(f"{i}.f_2.bias")
+
+assert len(block_four_sd_orig) == 0
+
+block_four = ResnetDownsampleBlock2D(
+ in_channels=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=False,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_four.load_state_dict(block_four_sd_new)
+
+
+print("MID BLOCK 1")
+
+mid_block_one_sd_orig = model.mid.state_dict()
+mid_block_one_sd_new = {}
+
+for i in range(2):
+ mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight")
+ mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias")
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight")
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias")
+ mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight")
+ mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias")
+
+assert len(mid_block_one_sd_orig) == 0
+
+mid_block_one = UNetMidBlock2D(
+ in_channels=1024,
+ temb_channels=1280,
+ num_layers=1,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+ add_attention=False,
+)
+
+mid_block_one.load_state_dict(mid_block_one_sd_new)
+
+print("UP BLOCK ONE")
+
+up_block_one_sd_orig = model.up[-1].state_dict()
+up_block_one_sd_new = {}
+
+for i in range(4):
+ up_block_one_sd_new[f"resnets.{i}.norm1.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_one_sd_new[f"resnets.{i}.norm1.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv1.weight"] = up_block_one_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv1.bias"] = up_block_one_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_one_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_one_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_one_sd_new[f"resnets.{i}.norm2.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_one_sd_new[f"resnets.{i}.norm2.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv2.weight"] = up_block_one_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv2.bias"] = up_block_one_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_one_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_one_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_one_sd_new["upsamplers.0.norm1.weight"] = up_block_one_sd_orig.pop("4.gn_1.weight")
+up_block_one_sd_new["upsamplers.0.norm1.bias"] = up_block_one_sd_orig.pop("4.gn_1.bias")
+up_block_one_sd_new["upsamplers.0.conv1.weight"] = up_block_one_sd_orig.pop("4.f_1.weight")
+up_block_one_sd_new["upsamplers.0.conv1.bias"] = up_block_one_sd_orig.pop("4.f_1.bias")
+up_block_one_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_one_sd_orig.pop("4.f_t.weight")
+up_block_one_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_one_sd_orig.pop("4.f_t.bias")
+up_block_one_sd_new["upsamplers.0.norm2.weight"] = up_block_one_sd_orig.pop("4.gn_2.weight")
+up_block_one_sd_new["upsamplers.0.norm2.bias"] = up_block_one_sd_orig.pop("4.gn_2.bias")
+up_block_one_sd_new["upsamplers.0.conv2.weight"] = up_block_one_sd_orig.pop("4.f_2.weight")
+up_block_one_sd_new["upsamplers.0.conv2.bias"] = up_block_one_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_one_sd_orig) == 0
+
+up_block_one = ResnetUpsampleBlock2D(
+ in_channels=1024,
+ prev_output_channel=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_one.load_state_dict(up_block_one_sd_new)
+
+print("UP BLOCK TWO")
+
+up_block_two_sd_orig = model.up[-2].state_dict()
+up_block_two_sd_new = {}
+
+for i in range(4):
+ up_block_two_sd_new[f"resnets.{i}.norm1.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_two_sd_new[f"resnets.{i}.norm1.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv1.weight"] = up_block_two_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv1.bias"] = up_block_two_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_two_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_two_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_two_sd_new[f"resnets.{i}.norm2.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_two_sd_new[f"resnets.{i}.norm2.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv2.weight"] = up_block_two_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv2.bias"] = up_block_two_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_two_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_two_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_two_sd_new["upsamplers.0.norm1.weight"] = up_block_two_sd_orig.pop("4.gn_1.weight")
+up_block_two_sd_new["upsamplers.0.norm1.bias"] = up_block_two_sd_orig.pop("4.gn_1.bias")
+up_block_two_sd_new["upsamplers.0.conv1.weight"] = up_block_two_sd_orig.pop("4.f_1.weight")
+up_block_two_sd_new["upsamplers.0.conv1.bias"] = up_block_two_sd_orig.pop("4.f_1.bias")
+up_block_two_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_two_sd_orig.pop("4.f_t.weight")
+up_block_two_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_two_sd_orig.pop("4.f_t.bias")
+up_block_two_sd_new["upsamplers.0.norm2.weight"] = up_block_two_sd_orig.pop("4.gn_2.weight")
+up_block_two_sd_new["upsamplers.0.norm2.bias"] = up_block_two_sd_orig.pop("4.gn_2.bias")
+up_block_two_sd_new["upsamplers.0.conv2.weight"] = up_block_two_sd_orig.pop("4.f_2.weight")
+up_block_two_sd_new["upsamplers.0.conv2.bias"] = up_block_two_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_two_sd_orig) == 0
+
+up_block_two = ResnetUpsampleBlock2D(
+ in_channels=640,
+ prev_output_channel=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_two.load_state_dict(up_block_two_sd_new)
+
+print("UP BLOCK THREE")
+
+up_block_three_sd_orig = model.up[-3].state_dict()
+up_block_three_sd_new = {}
+
+for i in range(4):
+ up_block_three_sd_new[f"resnets.{i}.norm1.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_three_sd_new[f"resnets.{i}.norm1.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv1.weight"] = up_block_three_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv1.bias"] = up_block_three_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_three_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_three_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_three_sd_new[f"resnets.{i}.norm2.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_three_sd_new[f"resnets.{i}.norm2.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv2.weight"] = up_block_three_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv2.bias"] = up_block_three_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_three_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_three_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_three_sd_new["upsamplers.0.norm1.weight"] = up_block_three_sd_orig.pop("4.gn_1.weight")
+up_block_three_sd_new["upsamplers.0.norm1.bias"] = up_block_three_sd_orig.pop("4.gn_1.bias")
+up_block_three_sd_new["upsamplers.0.conv1.weight"] = up_block_three_sd_orig.pop("4.f_1.weight")
+up_block_three_sd_new["upsamplers.0.conv1.bias"] = up_block_three_sd_orig.pop("4.f_1.bias")
+up_block_three_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_three_sd_orig.pop("4.f_t.weight")
+up_block_three_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_three_sd_orig.pop("4.f_t.bias")
+up_block_three_sd_new["upsamplers.0.norm2.weight"] = up_block_three_sd_orig.pop("4.gn_2.weight")
+up_block_three_sd_new["upsamplers.0.norm2.bias"] = up_block_three_sd_orig.pop("4.gn_2.bias")
+up_block_three_sd_new["upsamplers.0.conv2.weight"] = up_block_three_sd_orig.pop("4.f_2.weight")
+up_block_three_sd_new["upsamplers.0.conv2.bias"] = up_block_three_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_three_sd_orig) == 0
+
+up_block_three = ResnetUpsampleBlock2D(
+ in_channels=320,
+ prev_output_channel=1024,
+ out_channels=640,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_three.load_state_dict(up_block_three_sd_new)
+
+print("UP BLOCK FOUR")
+
+up_block_four_sd_orig = model.up[-4].state_dict()
+up_block_four_sd_new = {}
+
+for i in range(4):
+ up_block_four_sd_new[f"resnets.{i}.norm1.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_four_sd_new[f"resnets.{i}.norm1.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv1.weight"] = up_block_four_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv1.bias"] = up_block_four_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_four_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_four_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_four_sd_new[f"resnets.{i}.norm2.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_four_sd_new[f"resnets.{i}.norm2.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv2.weight"] = up_block_four_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv2.bias"] = up_block_four_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_four_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_four_sd_orig.pop(f"{i}.f_s.bias")
+
+assert len(up_block_four_sd_orig) == 0
+
+up_block_four = ResnetUpsampleBlock2D(
+ in_channels=320,
+ prev_output_channel=640,
+ out_channels=320,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=False,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_four.load_state_dict(up_block_four_sd_new)
+
+print("initial projection (conv_in)")
+
+conv_in_sd_orig = model.embed_image.state_dict()
+conv_in_sd_new = {}
+
+conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight")
+conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias")
+
+assert len(conv_in_sd_orig) == 0
+
+block_out_channels = [320, 640, 1024, 1024]
+
+in_channels = 7
+conv_in_kernel = 3
+conv_in_padding = (conv_in_kernel - 1) // 2
+conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
+
+conv_in.load_state_dict(conv_in_sd_new)
+
+print("out projection (conv_out) (conv_norm_out)")
+out_channels = 6
+norm_num_groups = 32
+norm_eps = 1e-5
+act_fn = "silu"
+conv_out_kernel = 3
+conv_out_padding = (conv_out_kernel - 1) // 2
+conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+# the original applies SiLU via torch.nn.functional directly, so no separate activation module is created here
+# conv_act = get_activation(act_fn)
+conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)
+
+conv_norm_out.load_state_dict(model.output.gn.state_dict())
+conv_out.load_state_dict(model.output.f.state_dict())
+
+print("timestep projection (time_proj) (time_embedding)")
+
+f1_sd = model.embed_time.f_1.state_dict()
+f2_sd = model.embed_time.f_2.state_dict()
+
+time_embedding_sd = {
+ "linear_1.weight": f1_sd.pop("weight"),
+ "linear_1.bias": f1_sd.pop("bias"),
+ "linear_2.weight": f2_sd.pop("weight"),
+ "linear_2.bias": f2_sd.pop("bias"),
+}
+
+assert len(f1_sd) == 0
+assert len(f2_sd) == 0
+
+time_embedding_type = "learned"
+num_train_timesteps = 1024
+time_embedding_dim = 1280
+
+time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
+timestep_input_dim = block_out_channels[0]
+
+time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim)
+
+time_proj.load_state_dict(model.embed_time.emb.state_dict())
+time_embedding.load_state_dict(time_embedding_sd)
+
+print("CONVERT")
+
+time_embedding.to("cuda")
+time_proj.to("cuda")
+conv_in.to("cuda")
+
+block_one.to("cuda")
+block_two.to("cuda")
+block_three.to("cuda")
+block_four.to("cuda")
+
+mid_block_one.to("cuda")
+
+up_block_one.to("cuda")
+up_block_two.to("cuda")
+up_block_three.to("cuda")
+up_block_four.to("cuda")
+
+conv_norm_out.to("cuda")
+conv_out.to("cuda")
+
+model.time_proj = time_proj
+model.time_embedding = time_embedding
+model.embed_image = conv_in
+
+model.down[0] = block_one
+model.down[1] = block_two
+model.down[2] = block_three
+model.down[3] = block_four
+
+model.mid = mid_block_one
+
+model.up[-1] = up_block_one
+model.up[-2] = up_block_two
+model.up[-3] = up_block_three
+model.up[-4] = up_block_four
+
+model.output.gn = conv_norm_out
+model.output.f = conv_out
+
+model.converted = True
+
+sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_new, "con_new.png")
+
+assert (sample_consistency_orig == sample_consistency_new).all()
+
+print("making unet")
+
+unet = UNet2DModel(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ down_block_types=(
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ ),
+ up_block_types=(
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ block_out_channels=block_out_channels,
+ layers_per_block=3,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ resnet_time_scale_shift="scale_shift",
+ time_embedding_type="learned",
+ num_train_timesteps=num_train_timesteps,
+ add_attention=False,
+)
+
+unet_state_dict = {}
+
+
+def add_state_dict(prefix, mod):
+ for k, v in mod.state_dict().items():
+ unet_state_dict[f"{prefix}.{k}"] = v
+
+
+add_state_dict("conv_in", conv_in)
+add_state_dict("time_proj", time_proj)
+add_state_dict("time_embedding", time_embedding)
+add_state_dict("down_blocks.0", block_one)
+add_state_dict("down_blocks.1", block_two)
+add_state_dict("down_blocks.2", block_three)
+add_state_dict("down_blocks.3", block_four)
+add_state_dict("mid_block", mid_block_one)
+add_state_dict("up_blocks.0", up_block_one)
+add_state_dict("up_blocks.1", up_block_two)
+add_state_dict("up_blocks.2", up_block_three)
+add_state_dict("up_blocks.3", up_block_four)
+add_state_dict("conv_norm_out", conv_norm_out)
+add_state_dict("conv_out", conv_out)
+
+unet.load_state_dict(unet_state_dict)
+
+print("running with diffusers unet")
+
+unet.to("cuda")
+
+decoder_consistency.ckpt = unet
+
+sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_new_2, "con_new_2.png")
+
+assert (sample_consistency_orig == sample_consistency_new_2).all()
+
+print("running with diffusers model")
+
+Encoder.old_constructor = Encoder.__init__
+
+
+def new_constructor(self, **kwargs):
+ self.old_constructor(**kwargs)
+ self.constructor_arguments = kwargs
+
+
+Encoder.__init__ = new_constructor
+
+
+vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+consistency_vae = ConsistencyDecoderVAE(
+ encoder_args=vae.encoder.constructor_arguments,
+ decoder_args=unet.config,
+ scaling_factor=vae.config.scaling_factor,
+ block_out_channels=vae.config.block_out_channels,
+ latent_channels=vae.config.latent_channels,
+)
+consistency_vae.encoder.load_state_dict(vae.encoder.state_dict())
+consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict())
+consistency_vae.decoder_unet.load_state_dict(unet.state_dict())
+
+consistency_vae.to(dtype=torch.float16, device="cuda")
+
+sample_consistency_new_3 = consistency_vae.decode(
+ 0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0)
+).sample
+
+print("max difference")
+print((sample_consistency_orig - sample_consistency_new_3).abs().max())
+print("total difference")
+print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
+# assert (sample_consistency_orig == sample_consistency_new_3).all()
+
+print("running with diffusers pipeline")
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
+)
+pipe.to("cuda")
+
+pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png")
+
+
+if args.save_pretrained is not None:
+ consistency_vae.save_pretrained(args.save_pretrained)
diff --git a/diffusers/scripts/convert_consistency_to_diffusers.py b/diffusers/scripts/convert_consistency_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f8b4ddca8efd5e4392754f12b67448adad8c0b7
--- /dev/null
+++ b/diffusers/scripts/convert_consistency_to_diffusers.py
@@ -0,0 +1,315 @@
+import argparse
+import os
+
+import torch
+
+from diffusers import (
+ CMStochasticIterativeScheduler,
+ ConsistencyModelPipeline,
+ UNet2DModel,
+)
+
+
+TEST_UNET_CONFIG = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "layers_per_block": 2,
+ "num_class_embeds": 1000,
+ "block_out_channels": [32, 64],
+ "attention_head_dim": 8,
+ "down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "AttnDownBlock2D",
+ ],
+ "up_block_types": [
+ "AttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "resnet_time_scale_shift": "scale_shift",
+ "attn_norm_num_groups": 32,
+ "upsample_type": "resnet",
+ "downsample_type": "resnet",
+}
+
+IMAGENET_64_UNET_CONFIG = {
+ "sample_size": 64,
+ "in_channels": 3,
+ "out_channels": 3,
+ "layers_per_block": 3,
+ "num_class_embeds": 1000,
+ "block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4],
+ "attention_head_dim": 64,
+ "down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ ],
+ "up_block_types": [
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "resnet_time_scale_shift": "scale_shift",
+ "attn_norm_num_groups": 32,
+ "upsample_type": "resnet",
+ "downsample_type": "resnet",
+}
+
+LSUN_256_UNET_CONFIG = {
+ "sample_size": 256,
+ "in_channels": 3,
+ "out_channels": 3,
+ "layers_per_block": 2,
+ "num_class_embeds": None,
+ "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
+ "attention_head_dim": 64,
+ "down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ ],
+ "up_block_types": [
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "resnet_time_scale_shift": "default",
+ "upsample_type": "resnet",
+ "downsample_type": "resnet",
+}
+
+CD_SCHEDULER_CONFIG = {
+ "num_train_timesteps": 40,
+ "sigma_min": 0.002,
+ "sigma_max": 80.0,
+}
+
+CT_IMAGENET_64_SCHEDULER_CONFIG = {
+ "num_train_timesteps": 201,
+ "sigma_min": 0.002,
+ "sigma_max": 80.0,
+}
+
+CT_LSUN_256_SCHEDULER_CONFIG = {
+ "num_train_timesteps": 151,
+ "sigma_min": 0.002,
+ "sigma_max": 80.0,
+}
+
+
+def str2bool(v):
+ """
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+ """
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("boolean value expected")
+
+
+def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
+ new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"]
+ new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"]
+ new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"]
+ new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"]
+ new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"]
+ new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"]
+ new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"]
+ new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"]
+ new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"]
+ new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"]
+
+ if has_skip:
+ new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"]
+ new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"]
+
+ return new_checkpoint
+
+
+def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
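+ # The original checkpoint stores attention as a fused qkv projection implemented as a kernel-size-1 conv;
+ # split it and squeeze the kernel dimension so the weights fit the separate to_q / to_k / to_v linear layers.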
+ weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
+ bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)
+
+ new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
+ new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]
+
+ new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1)
+
+ new_checkpoint[f"{new_prefix}.to_out.0.weight"] = (
+ checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1)
+ )
+ new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1)
+
+ return new_checkpoint
+
+
+def con_pt_to_diffuser(checkpoint_path: str, unet_config):
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ new_checkpoint = {}
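+ # Re-key the OpenAI-style checkpoint (time_embed / input_blocks / middle_block / output_blocks) into the
+ # diffusers UNet2DModel naming scheme, walking the blocks in the order declared by unet_config.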
+
+ new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
+
+ if unet_config["num_class_embeds"] is not None:
+ new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]
+
+ new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
+
+ down_block_types = unet_config["down_block_types"]
+ layers_per_block = unet_config["layers_per_block"]
+ attention_head_dim = unet_config["attention_head_dim"]
+ channels_list = unet_config["block_out_channels"]
+ current_layer = 1
+ prev_channels = channels_list[0]
+
+ for i, layer_type in enumerate(down_block_types):
+ current_channels = channels_list[i]
+ downsample_block_has_skip = current_channels != prev_channels
+ if layer_type == "ResnetDownsampleBlock2D":
+ for j in range(layers_per_block):
+ new_prefix = f"down_blocks.{i}.resnets.{j}"
+ old_prefix = f"input_blocks.{current_layer}.0"
+ has_skip = True if j == 0 and downsample_block_has_skip else False
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
+ current_layer += 1
+
+ elif layer_type == "AttnDownBlock2D":
+ for j in range(layers_per_block):
+ new_prefix = f"down_blocks.{i}.resnets.{j}"
+ old_prefix = f"input_blocks.{current_layer}.0"
+ has_skip = True if j == 0 and downsample_block_has_skip else False
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
+ new_prefix = f"down_blocks.{i}.attentions.{j}"
+ old_prefix = f"input_blocks.{current_layer}.1"
+ new_checkpoint = convert_attention(
+ checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
+ )
+ current_layer += 1
+
+ if i != len(down_block_types) - 1:
+ new_prefix = f"down_blocks.{i}.downsamplers.0"
+ old_prefix = f"input_blocks.{current_layer}.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+ current_layer += 1
+
+ prev_channels = current_channels
+
+ # hardcoded the mid-block for now
+ new_prefix = "mid_block.resnets.0"
+ old_prefix = "middle_block.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+ new_prefix = "mid_block.attentions.0"
+ old_prefix = "middle_block.1"
+ new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
+ new_prefix = "mid_block.resnets.1"
+ old_prefix = "middle_block.2"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+
+ current_layer = 0
+ up_block_types = unet_config["up_block_types"]
+
+ for i, layer_type in enumerate(up_block_types):
+ if layer_type == "ResnetUpsampleBlock2D":
+ for j in range(layers_per_block + 1):
+ new_prefix = f"up_blocks.{i}.resnets.{j}"
+ old_prefix = f"output_blocks.{current_layer}.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
+ current_layer += 1
+
+ if i != len(up_block_types) - 1:
+ new_prefix = f"up_blocks.{i}.upsamplers.0"
+ old_prefix = f"output_blocks.{current_layer-1}.1"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+ elif layer_type == "AttnUpBlock2D":
+ for j in range(layers_per_block + 1):
+ new_prefix = f"up_blocks.{i}.resnets.{j}"
+ old_prefix = f"output_blocks.{current_layer}.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
+ new_prefix = f"up_blocks.{i}.attentions.{j}"
+ old_prefix = f"output_blocks.{current_layer}.1"
+ new_checkpoint = convert_attention(
+ checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
+ )
+ current_layer += 1
+
+ if i != len(up_block_types) - 1:
+ new_prefix = f"up_blocks.{i}.upsamplers.0"
+ old_prefix = f"output_blocks.{current_layer-1}.2"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+
+ new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
+
+ return new_checkpoint
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
+ parser.add_argument(
+ "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
+ )
+ parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.")
+
+ args = parser.parse_args()
+ args.class_cond = str2bool(args.class_cond)
+
+ ckpt_name = os.path.basename(args.unet_path)
+ print(f"Checkpoint: {ckpt_name}")
+
+ # Get U-Net config
+ if "imagenet64" in ckpt_name:
+ unet_config = IMAGENET_64_UNET_CONFIG
+ elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
+ unet_config = LSUN_256_UNET_CONFIG
+ elif "test" in ckpt_name:
+ unet_config = TEST_UNET_CONFIG
+ else:
+ raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
+
+ if not args.class_cond:
+ unet_config["num_class_embeds"] = None
+
+ converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)
+
+ image_unet = UNet2DModel(**unet_config)
+ image_unet.load_state_dict(converted_unet_ckpt)
+
+ # Get scheduler config
+ if "cd" in ckpt_name or "test" in ckpt_name:
+ scheduler_config = CD_SCHEDULER_CONFIG
+ elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
+ scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
+ elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
+ scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
+ else:
+ raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
+
+ cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)
+
+ consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
+ consistency_model.save_pretrained(args.dump_path)
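+
+ # Example (hypothetical checkpoint name; the UNet and scheduler configs are selected from it):
+ #   python convert_consistency_to_diffusers.py --unet_path cd_imagenet64_l2.pt --dump_path ./cm-imagenet64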
diff --git a/diffusers/scripts/convert_dance_diffusion_to_diffusers.py b/diffusers/scripts/convert_dance_diffusion_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce69bfe2bfc8e4828517004e5bea6335339f3466
--- /dev/null
+++ b/diffusers/scripts/convert_dance_diffusion_to_diffusers.py
@@ -0,0 +1,345 @@
+#!/usr/bin/env python3
+import argparse
+import math
+import os
+from copy import deepcopy
+
+import requests
+import torch
+from audio_diffusion.models import DiffusionAttnUnet1D
+from diffusion import sampling
+from torch import nn
+
+from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
+
+
+MODELS_MAP = {
+ "gwf-440k": {
+ "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
+ "sample_rate": 48000,
+ "sample_size": 65536,
+ },
+ "jmann-small-190k": {
+ "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
+ "sample_rate": 48000,
+ "sample_size": 65536,
+ },
+ "jmann-large-580k": {
+ "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
+ "sample_rate": 48000,
+ "sample_size": 131072,
+ },
+ "maestro-uncond-150k": {
+ "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
+ "sample_rate": 16000,
+ "sample_size": 65536,
+ },
+ "unlocked-uncond-250k": {
+ "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
+ "sample_rate": 16000,
+ "sample_size": 65536,
+ },
+ "honk-140k": {
+ "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
+ "sample_rate": 16000,
+ "sample_size": 65536,
+ },
+}
+
+
+def alpha_sigma_to_t(alpha, sigma):
+ """Returns a timestep, given the scaling factors for the clean image and for
+ the noise."""
+ return torch.atan2(sigma, alpha) / math.pi * 2
+
+
+def get_crash_schedule(t):
+ sigma = torch.sin(t * math.pi / 2) ** 2
+ alpha = (1 - sigma**2) ** 0.5
+ return alpha_sigma_to_t(alpha, sigma)
+
+
+class Object(object):
+ pass
+
+
+class DiffusionUncond(nn.Module):
+ def __init__(self, global_args):
+ super().__init__()
+
+ self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
+ self.diffusion_ema = deepcopy(self.diffusion)
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
+
+
+def download(model_name):
+ url = MODELS_MAP[model_name]["url"]
+ r = requests.get(url, stream=True)
+
+ local_filename = f"./{model_name}.ckpt"
+ with open(local_filename, "wb") as fp:
+ for chunk in r.iter_content(chunk_size=8192):
+ fp.write(chunk)
+
+ return local_filename
+
+
+DOWN_NUM_TO_LAYER = {
+ "1": "resnets.0",
+ "2": "attentions.0",
+ "3": "resnets.1",
+ "4": "attentions.1",
+ "5": "resnets.2",
+ "6": "attentions.2",
+}
+UP_NUM_TO_LAYER = {
+ "8": "resnets.0",
+ "9": "attentions.0",
+ "10": "resnets.1",
+ "11": "attentions.1",
+ "12": "resnets.2",
+ "13": "attentions.2",
+}
+MID_NUM_TO_LAYER = {
+ "1": "resnets.0",
+ "2": "attentions.0",
+ "3": "resnets.1",
+ "4": "attentions.1",
+ "5": "resnets.2",
+ "6": "attentions.2",
+ "8": "resnets.3",
+ "9": "attentions.3",
+ "10": "resnets.4",
+ "11": "attentions.4",
+ "12": "resnets.5",
+ "13": "attentions.5",
+}
+DEPTH_0_TO_LAYER = {
+ "0": "resnets.0",
+ "1": "resnets.1",
+ "2": "resnets.2",
+ "4": "resnets.0",
+ "5": "resnets.1",
+ "6": "resnets.2",
+}
+
+RES_CONV_MAP = {
+ "skip": "conv_skip",
+ "main.0": "conv_1",
+ "main.1": "group_norm_1",
+ "main.3": "conv_2",
+ "main.4": "group_norm_2",
+}
+
+ATTN_MAP = {
+ "norm": "group_norm",
+ "qkv_proj": ["query", "key", "value"],
+ "out_proj": ["proj_attn"],
+}
+
+
+def convert_resconv_naming(name):
+ if name.startswith("skip"):
+ return name.replace("skip", RES_CONV_MAP["skip"])
+
+ # name has to be of format main.{digit}
+ if not name.startswith("main."):
+ raise ValueError(f"ResConvBlock error with {name}")
+
+ return name.replace(name[:6], RES_CONV_MAP[name[:6]])
+
+
+def convert_attn_naming(name):
+ for key, value in ATTN_MAP.items():
+ if name.startswith(key) and not isinstance(value, list):
+ return name.replace(key, value)
+ elif name.startswith(key):
+ return [name.replace(key, v) for v in value]
+ raise ValueError(f"Attn error with {name}")
+
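+# Example mappings produced by the two helpers above (illustration only):
+#   convert_resconv_naming("main.0.weight") -> "conv_1.weight"
+#   convert_attn_naming("qkv_proj.weight") -> ["query.weight", "key.weight", "value.weight"]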
+
+def rename(input_string, max_depth=13):
+ string = input_string
+
+ if string.split(".")[0] == "timestep_embed":
+ return string.replace("timestep_embed", "time_proj")
+
+ depth = 0
+ if string.startswith("net.3."):
+ depth += 1
+ string = string[6:]
+ elif string.startswith("net."):
+ string = string[4:]
+
+ while string.startswith("main.7."):
+ depth += 1
+ string = string[7:]
+
+ if string.startswith("main."):
+ string = string[5:]
+
+ # mid block
+ if string[:2].isdigit():
+ layer_num = string[:2]
+ string_left = string[2:]
+ else:
+ layer_num = string[0]
+ string_left = string[1:]
+
+ if depth == max_depth:
+ new_layer = MID_NUM_TO_LAYER[layer_num]
+ prefix = "mid_block"
+ elif depth > 0 and int(layer_num) < 7:
+ new_layer = DOWN_NUM_TO_LAYER[layer_num]
+ prefix = f"down_blocks.{depth}"
+ elif depth > 0 and int(layer_num) > 7:
+ new_layer = UP_NUM_TO_LAYER[layer_num]
+ prefix = f"up_blocks.{max_depth - depth - 1}"
+ elif depth == 0:
+ new_layer = DEPTH_0_TO_LAYER[layer_num]
+ prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"
+
+ if not string_left.startswith("."):
+ raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")
+
+ string_left = string_left[1:]
+
+ if "resnets" in new_layer:
+ string_left = convert_resconv_naming(string_left)
+ elif "attentions" in new_layer:
+ new_string_left = convert_attn_naming(string_left)
+ string_left = new_string_left
+
+ if not isinstance(string_left, list):
+ new_string = prefix + "." + new_layer + "." + string_left
+ else:
+ new_string = [prefix + "." + new_layer + "." + s for s in string_left]
+ return new_string
+
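+# rename() strips the nested "net." / "main.7." prefixes to recover the block
+# depth, then rewrites the layer index through the tables above; the simplest
+# case is rename("timestep_embed.weight") -> "time_proj.weight".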
+
+def rename_orig_weights(state_dict):
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.endswith("kernel"):
+ # up- and downsample layers don't have trainable weights
+ continue
+
+ new_k = rename(k)
+
+ # check if we need to transform from Conv => Linear for attention
+ if isinstance(new_k, list):
+ new_state_dict = transform_conv_attns(new_state_dict, new_k, v)
+ else:
+ new_state_dict[new_k] = v
+
+ return new_state_dict
+
+
+def transform_conv_attns(new_state_dict, new_k, v):
+ if len(new_k) == 1:
+ if len(v.shape) == 3:
+ # weight
+ new_state_dict[new_k[0]] = v[:, :, 0]
+ else:
+ # bias
+ new_state_dict[new_k[0]] = v
+ else:
+ # qkv matrices
+ tripled_shape = v.shape[0]
+ single_shape = tripled_shape // 3
+ for i in range(3):
+ if len(v.shape) == 3:
+ new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]
+ else:
+ new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]
+ return new_state_dict
+
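+# Note on the qkv split above: a fused projection weight of shape [3*C, C, 1] is
+# cut into three [C, C] chunks (the trailing conv dim is dropped) and stored under
+# the separate query/key/value keys emitted by convert_attn_naming; 1-D biases are
+# split the same way without indexing the last dim.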
+
+def main(args):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ model_name = args.model_path.split("/")[-1].split(".")[0]
+ if not os.path.isfile(args.model_path):
+ assert (
+ model_name == args.model_path
+ ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
+ args.model_path = download(model_name)
+
+ sample_rate = MODELS_MAP[model_name]["sample_rate"]
+ sample_size = MODELS_MAP[model_name]["sample_size"]
+
+ config = Object()
+ config.sample_size = sample_size
+ config.sample_rate = sample_rate
+ config.latent_dim = 0
+
+ diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
+ diffusers_state_dict = diffusers_model.state_dict()
+
+ orig_model = DiffusionUncond(config)
+ orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
+ orig_model = orig_model.diffusion_ema.eval()
+ orig_model_state_dict = orig_model.state_dict()
+ renamed_state_dict = rename_orig_weights(orig_model_state_dict)
+
+ renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
+ diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())
+
+ assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
+ assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
+
+ for key, value in renamed_state_dict.items():
+ assert (
+ diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
+ ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
+ if key == "time_proj.weight":
+ value = value.squeeze()
+
+ diffusers_state_dict[key] = value
+
+ diffusers_model.load_state_dict(diffusers_state_dict)
+
+ steps = 100
+ seed = 33
+
+ diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)
+
+ generator = torch.manual_seed(seed)
+ noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)
+
+ t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
+ step_list = get_crash_schedule(t)
+
+ pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)
+
+ generator = torch.manual_seed(33)
+ audio = pipe(num_inference_steps=steps, generator=generator).audios
+
+ generated = sampling.iplms_sample(orig_model, noise, step_list, {})
+ generated = generated.clamp(-1, 1)
+
+ diff_sum = (generated - audio).abs().sum()
+ diff_max = (generated - audio).abs().max()
+
+ if args.save:
+ pipe.save_pretrained(args.checkpoint_path)
+
+ print("Diff sum", diff_sum)
+ print("Diff max", diff_max)
+
+ assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"
+
+ print(f"Conversion for {model_name} successful!")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument(
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
+ )
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..46595784b0bac0016b623b7122082275248363e9
--- /dev/null
+++ b/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py
@@ -0,0 +1,431 @@
+import argparse
+import json
+
+import torch
+
+from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
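+# Example (illustration): shave_segments("down.0.block.1.conv1.weight", 2) returns
+# "block.1.conv1.weight", while a negative value such as -1 drops the trailing
+# ".weight" segment instead.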
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+ new_item = new_item.replace("block.", "resnets.")
+ new_item = new_item.replace("conv_shorcut", "conv1")
+ new_item = new_item.replace("in_shortcut", "conv_shortcut")
+ new_item = new_item.replace("temb_proj", "time_emb_proj")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # In `model.mid`, the layer is called `attn`.
+ if not in_mid:
+ new_item = new_item.replace("attn", "attentions")
+ new_item = new_item.replace(".k.", ".key.")
+ new_item = new_item.replace(".v.", ".value.")
+ new_item = new_item.replace(".q.", ".query.")
+
+ new_item = new_item.replace("proj_out", "proj_attn")
+ new_item = new_item.replace("norm", "group_norm")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ if attention_paths_to_split is not None:
+ if config is None:
+ raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")
+
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
+ checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
+ checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()
+
+ for path in paths:
+ new_path = path["new"]
+
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ new_path = new_path.replace("down.", "down_blocks.")
+ new_path = new_path.replace("up.", "up_blocks.")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ if "attentions" in new_path:
+ checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
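+# When `attention_paths_to_split` is passed, the fused qkv tensor at each listed
+# path is reshaped per attention head, split into three slices, and written to the
+# query/key/value targets named in its path map before the plain renames run.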
+
+def convert_ddpm_checkpoint(checkpoint, config):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
+ new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
+
+ new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
+ new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
+ new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
+ new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
+
+ num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
+ down_blocks = {
+ layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
+ up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
+
+ for i in range(num_down_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+
+ if any("downsample" in layer for layer in down_blocks[i]):
+ new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+ f"down.{i}.downsample.op.weight"
+ ]
+ new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
+ # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
+ # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
+
+ if any("block" in layer for layer in down_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+ if any("attn" in layer for layer in down_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+ mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+ mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+ mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+ # Mid new 2
+ paths = renew_resnet_paths(mid_block_1_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+ )
+
+ paths = renew_resnet_paths(mid_block_2_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+ )
+
+ paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+ )
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+
+ if any("upsample" in layer for layer in up_blocks[i]):
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+ f"up.{i}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
+
+ if any("block" in layer for layer in up_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ if any("attn" in layer for layer in up_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
+ return new_checkpoint
+
+
+def convert_vq_autoenc_checkpoint(checkpoint, config):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
+
+ new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
+
+ new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
+
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
+ down_blocks = {
+ layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
+ up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
+
+ for i in range(num_down_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+
+ if any("downsample" in layer for layer in down_blocks[i]):
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+ f"encoder.down.{i}.downsample.conv.weight"
+ ]
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
+ f"encoder.down.{i}.downsample.conv.bias"
+ ]
+
+ if any("block" in layer for layer in down_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+ if any("attn" in layer for layer in down_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+ mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+ mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+ mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+ # Mid new 2
+ paths = renew_resnet_paths(mid_block_1_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+ )
+
+ paths = renew_resnet_paths(mid_block_2_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+ )
+
+ paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+ )
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+
+ if any("upsample" in layer for layer in up_blocks[i]):
+ new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+ f"decoder.up.{i}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
+ f"decoder.up.{i}.upsample.conv.bias"
+ ]
+
+ if any("block" in layer for layer in up_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ if any("attn" in layer for layer in up_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
+ new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
+ if "quantize.embedding.weight" in checkpoint:
+ new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
+ new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
+
+ return new_checkpoint
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+
+ parser.add_argument(
+ "--config_file",
+ default=None,
+ type=str,
+ required=True,
+ help="The config json file corresponding to the architecture.",
+ )
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+
+ args = parser.parse_args()
+ checkpoint = torch.load(args.checkpoint_path)
+
+ with open(args.config_file) as f:
+ config = json.loads(f.read())
+
+ # unet case
+ key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()}
+ if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
+ converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
+ else:
+ converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
+
+ if "ddpm" in config:
+ del config["ddpm"]
+
+ if config["_class_name"] == "VQModel":
+ model = VQModel(**config)
+ model.load_state_dict(converted_checkpoint)
+ model.save_pretrained(args.dump_path)
+ elif config["_class_name"] == "AutoencoderKL":
+ model = AutoencoderKL(**config)
+ model.load_state_dict(converted_checkpoint)
+ model.save_pretrained(args.dump_path)
+ else:
+ model = UNet2DModel(**config)
+ model.load_state_dict(converted_checkpoint)
+
+ scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
+
+ pipe = DDPMPipeline(unet=model, scheduler=scheduler)
+ pipe.save_pretrained(args.dump_path)
diff --git a/diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py b/diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfb3871275cbc68809379596a9209e08f377936c
--- /dev/null
+++ b/diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py
@@ -0,0 +1,56 @@
+# Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
+# This means that you can input your diffusers-trained LoRAs and
+# get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next, and others.
+
+# To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
+# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
+# and run the script:
+# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
+# now you can use corgy.safetensors in your WebUI of choice!
+
+# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
+# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
+# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
+# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
+# Canonical diffusers training scripts:
+# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
+# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
+
+import argparse
+import os
+
+from safetensors.torch import load_file, save_file
+
+from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
+
+
+def convert_and_save(input_lora, output_lora=None):
+ if output_lora is None:
+ base_name = os.path.splitext(input_lora)[0]
+ output_lora = f"{base_name}_webui.safetensors"
+
+ diffusers_state_dict = load_file(input_lora)
+ peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+ save_file(kohya_state_dict, output_lora)
+
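+# Programmatic use (sketch mirroring the CLI below):
+#   convert_and_save("pytorch_lora_weights.safetensors")
+# writes "pytorch_lora_weights_webui.safetensors" next to the input file.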
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
+ parser.add_argument(
+ "--input_lora",
+ type=str,
+ required=True,
+ help="Path to the input LoRA model file in the diffusers format.",
+ )
+ parser.add_argument(
+ "--output_lora",
+ type=str,
+ required=False,
+ help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
+ )
+
+ args = parser.parse_args()
+
+ convert_and_save(args.input_lora, args.output_lora)
diff --git a/diffusers/scripts/convert_diffusers_to_original_sdxl.py b/diffusers/scripts/convert_diffusers_to_original_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..648d0376f72eb8d73663247ae40be62b1c960eda
--- /dev/null
+++ b/diffusers/scripts/convert_diffusers_to_original_sdxl.py
@@ -0,0 +1,350 @@
+# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+# *Only* converts the UNet, VAE, and Text Encoder.
+# Does not convert optimizer state or anything else.
+
+import argparse
+import os.path as osp
+import re
+
+import torch
+from safetensors.torch import load_file, save_file
+
+
+# =================#
+# UNet Conversion #
+# =================#
+
+unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+ # the following are for sdxl
+ ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
+ ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
+ ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
+ ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
+]
+
+unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+]
+
+unet_conversion_map_layer = []
+# hardcoded number of downblocks and resnets/attentions...
+# would need smarter logic for other networks.
+for i in range(3):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i > 0:
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(4):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i < 2:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
+
+hf_mid_atn_prefix = "mid_block.attentions.0."
+sd_mid_atn_prefix = "middle_block.1."
+unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
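+# Sample of the generated pairs (illustration): for i=1, j=0 the down-block loop
+# above emits ("input_blocks.4.0.", "down_blocks.1.resnets.0."), i.e. the SD prefix
+# on the left and the HF Diffusers prefix on the right.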
+
+def convert_unet_state_dict(unet_state_dict):
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+]
+
+for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+# this part accounts for mid blocks in both the encoder and the decoder
+for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ # the following are for SDXL
+ ("q.", "to_q."),
+ ("k.", "to_k."),
+ ("v.", "to_v."),
+ ("proj_out.", "to_out.0."),
+]
+
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
+ if w.ndim != 1:
+ return w.reshape(*w.shape, 1, 1)
+ else:
+ return w
+
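+# e.g. a [512, 512] HF linear weight becomes a [512, 512, 1, 1] conv weight for the
+# SD checkpoint, while 1-D bias tensors pass through unchanged (shapes illustrative).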
+
+def convert_vae_state_dict(vae_state_dict):
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+ return new_state_dict
+
+
+# =========================#
+# Text Encoder Conversion #
+# =========================#
+
+
+textenc_conversion_lst = [
+ # (stable-diffusion, HF Diffusers)
+ ("transformer.resblocks.", "text_model.encoder.layers."),
+ ("ln_1", "layer_norm1"),
+ ("ln_2", "layer_norm2"),
+ (".c_fc.", ".fc1."),
+ (".c_proj.", ".fc2."),
+ (".attn", ".self_attn"),
+ ("ln_final.", "text_model.final_layer_norm."),
+ ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
+ ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+code2idx = {"q": 0, "k": 1, "v": 2}
+
+
+def convert_openclip_text_enc_state_dict(text_enc_dict):
+ new_state_dict = {}
+ capture_qkv_weight = {}
+ capture_qkv_bias = {}
+ for k, v in text_enc_dict.items():
+ if (
+ k.endswith(".self_attn.q_proj.weight")
+ or k.endswith(".self_attn.k_proj.weight")
+ or k.endswith(".self_attn.v_proj.weight")
+ ):
+ k_pre = k[: -len(".q_proj.weight")]
+ k_code = k[-len("q_proj.weight")]
+ if k_pre not in capture_qkv_weight:
+ capture_qkv_weight[k_pre] = [None, None, None]
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
+ continue
+
+ if (
+ k.endswith(".self_attn.q_proj.bias")
+ or k.endswith(".self_attn.k_proj.bias")
+ or k.endswith(".self_attn.v_proj.bias")
+ ):
+ k_pre = k[: -len(".q_proj.bias")]
+ k_code = k[-len("q_proj.bias")]
+ if k_pre not in capture_qkv_bias:
+ capture_qkv_bias[k_pre] = [None, None, None]
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
+ continue
+
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+ new_state_dict[relabelled_key] = v
+
+ for k_pre, tensors in capture_qkv_weight.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+ for k_pre, tensors in capture_qkv_bias.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+ return new_state_dict
+
+
+def convert_openai_text_enc_state_dict(text_enc_dict):
+ return text_enc_dict
+
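+# Text encoder 1 (OpenAI CLIP) keys already match the SD layout, so the function
+# above is a passthrough and only a prefix is added in __main__; text encoder 2
+# (OpenCLIP) goes through the qkv re-fusion handled further up.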
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+ parser.add_argument(
+ "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
+ )
+
+ args = parser.parse_args()
+
+ assert args.model_path is not None, "Must provide a model path!"
+
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+ # Path for safetensors
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
+ text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
+ text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors")
+
+ # Load the models from safetensors if present, otherwise fall back to the PyTorch .bin files
+ if osp.exists(unet_path):
+ unet_state_dict = load_file(unet_path, device="cpu")
+ else:
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
+
+ if osp.exists(vae_path):
+ vae_state_dict = load_file(vae_path, device="cpu")
+ else:
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
+
+ if osp.exists(text_enc_path):
+ text_enc_dict = load_file(text_enc_path, device="cpu")
+ else:
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+
+ if osp.exists(text_enc_2_path):
+ text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
+ else:
+ text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin")
+ text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")
+
+ # Convert the UNet model
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+ # Convert the VAE model
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+ # Convert text encoder 1
+ text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
+ text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
+
+ # Convert text encoder 2
+ text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
+ text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
+ # We call the `.T.contiguous()` to match what's done in
+ # https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
+ text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
+ "conditioner.embedders.1.model.text_projection.weight"
+ ).T.contiguous()
+
+ # Put together new checkpoint
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
+
+ if args.half:
+ state_dict = {k: v.half() for k, v in state_dict.items()}
+
+ if args.use_safetensors:
+ save_file(state_dict, args.checkpoint_path)
+ else:
+ state_dict = {"state_dict": state_dict}
+ torch.save(state_dict, args.checkpoint_path)
diff --git a/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py b/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1b7df070c436b4d1a5fad2b9a096f1999aacc33
--- /dev/null
+++ b/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -0,0 +1,353 @@
+# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+# *Only* converts the UNet, VAE, and Text Encoder.
+# Does not convert optimizer state or anything else.
+
+import argparse
+import os.path as osp
+import re
+
+import torch
+from safetensors.torch import load_file, save_file
+
+
+# =================#
+# UNet Conversion #
+# =================#
+
+unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+]
+
+unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+]
+
+unet_conversion_map_layer = []
+# hardcoded number of downblocks and resnets/attentions...
+# would need smarter logic for other networks.
+for i in range(4):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i > 0:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+hf_mid_atn_prefix = "mid_block.attentions.0."
+sd_mid_atn_prefix = "middle_block.1."
+unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+def convert_unet_state_dict(unet_state_dict):
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+]
+
+for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+# this part accounts for mid blocks in both the encoder and the decoder
+for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ ("q.", "query."),
+ ("k.", "key."),
+ ("v.", "value."),
+ ("proj_out.", "proj_attn."),
+]
+
+# This is probably not the ideal solution, but it works.
+vae_extra_conversion_map = [
+ ("to_q", "q"),
+ ("to_k", "k"),
+ ("to_v", "v"),
+ ("to_out.0", "proj_out"),
+]
+
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
+ if w.ndim != 1:
+ return w.reshape(*w.shape, 1, 1)
+ else:
+ return w
+
+
+def convert_vae_state_dict(vae_state_dict):
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ keys_to_rename = {}
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+ for weight_name, real_weight_name in vae_extra_conversion_map:
+ if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
+ keys_to_rename[k] = k.replace(weight_name, real_weight_name)
+ for k, v in keys_to_rename.items():
+ if k in new_state_dict:
+ print(f"Renaming {k} to {v}")
+ new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
+ del new_state_dict[k]
+ return new_state_dict
+
+
+# =========================#
+# Text Encoder Conversion #
+# =========================#
+
+
+textenc_conversion_lst = [
+ # (stable-diffusion, HF Diffusers)
+ ("resblocks.", "text_model.encoder.layers."),
+ ("ln_1", "layer_norm1"),
+ ("ln_2", "layer_norm2"),
+ (".c_fc.", ".fc1."),
+ (".c_proj.", ".fc2."),
+ (".attn", ".self_attn"),
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+code2idx = {"q": 0, "k": 1, "v": 2}
+
+
+def convert_text_enc_state_dict_v20(text_enc_dict):
+ new_state_dict = {}
+ capture_qkv_weight = {}
+ capture_qkv_bias = {}
+ for k, v in text_enc_dict.items():
+ if (
+ k.endswith(".self_attn.q_proj.weight")
+ or k.endswith(".self_attn.k_proj.weight")
+ or k.endswith(".self_attn.v_proj.weight")
+ ):
+ k_pre = k[: -len(".q_proj.weight")]
+ k_code = k[-len("q_proj.weight")]
+ if k_pre not in capture_qkv_weight:
+ capture_qkv_weight[k_pre] = [None, None, None]
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
+ continue
+
+ if (
+ k.endswith(".self_attn.q_proj.bias")
+ or k.endswith(".self_attn.k_proj.bias")
+ or k.endswith(".self_attn.v_proj.bias")
+ ):
+ k_pre = k[: -len(".q_proj.bias")]
+ k_code = k[-len("q_proj.bias")]
+ if k_pre not in capture_qkv_bias:
+ capture_qkv_bias[k_pre] = [None, None, None]
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
+ continue
+
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+ new_state_dict[relabelled_key] = v
+
+ for k_pre, tensors in capture_qkv_weight.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+ for k_pre, tensors in capture_qkv_bias.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+ return new_state_dict
+
+
+def convert_text_enc_state_dict(text_enc_dict):
+ return text_enc_dict
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+ parser.add_argument(
+ "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
+ )
+
+ args = parser.parse_args()
+
+ assert args.model_path is not None, "Must provide a model path!"
+
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+ # Path for safetensors
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
+ text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
+
+ # Load the models from safetensors if present, otherwise fall back to the PyTorch .bin files
+ if osp.exists(unet_path):
+ unet_state_dict = load_file(unet_path, device="cpu")
+ else:
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
+
+ if osp.exists(vae_path):
+ vae_state_dict = load_file(vae_path, device="cpu")
+ else:
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
+
+ if osp.exists(text_enc_path):
+ text_enc_dict = load_file(text_enc_path, device="cpu")
+ else:
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+
+ # Convert the UNet model
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+ # Convert the VAE model
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
+
+ if is_v20_model:
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
+ else:
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+
+ # Put together new checkpoint
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
+ if args.half:
+ state_dict = {k: v.half() for k, v in state_dict.items()}
+
+ if args.use_safetensors:
+ save_file(state_dict, args.checkpoint_path)
+ else:
+ state_dict = {"state_dict": state_dict}
+ torch.save(state_dict, args.checkpoint_path)
diff --git a/diffusers/scripts/convert_dit_to_diffusers.py b/diffusers/scripts/convert_dit_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc127f69555c260f594e70444b1540faa196e3fb
--- /dev/null
+++ b/diffusers/scripts/convert_dit_to_diffusers.py
@@ -0,0 +1,162 @@
+import argparse
+import os
+
+import torch
+from torchvision.datasets.utils import download_url
+
+from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
+
+
+pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
+
+
+def download_model(model_name):
+ """
+ Downloads a pre-trained DiT model from the web.
+ """
+ local_path = f"pretrained_models/{model_name}"
+ if not os.path.isfile(local_path):
+ os.makedirs("pretrained_models", exist_ok=True)
+ web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
+ download_url(web_path, "pretrained_models")
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
+ return model
+
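+# download_model caches the checkpoint under ./pretrained_models/ on first use and
+# loads the state dict on CPU.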
+
+def main(args):
+ state_dict = download_model(pretrained_models[args.image_size])
+
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+ state_dict.pop("x_embedder.proj.weight")
+ state_dict.pop("x_embedder.proj.bias")
+
+ for depth in range(28):
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
+ "t_embedder.mlp.0.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
+ "t_embedder.mlp.0.bias"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
+ "t_embedder.mlp.2.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
+ "t_embedder.mlp.2.bias"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
+ "y_embedder.embedding_table.weight"
+ ]
+
+ state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
+ f"blocks.{depth}.adaLN_modulation.1.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
+ f"blocks.{depth}.adaLN_modulation.1.bias"
+ ]
+
+ q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)
+
+ state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+ state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
+ state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+ state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
+ state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
+
+ state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
+ f"blocks.{depth}.attn.proj.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]
+
+ state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
+ state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
+ state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
+ state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]
+
+ state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
+ state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
+ state_dict.pop(f"blocks.{depth}.attn.proj.weight")
+ state_dict.pop(f"blocks.{depth}.attn.proj.bias")
+ state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
+ state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
+ state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
+ state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
+ state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
+ state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
+
+ state_dict.pop("t_embedder.mlp.0.weight")
+ state_dict.pop("t_embedder.mlp.0.bias")
+ state_dict.pop("t_embedder.mlp.2.weight")
+ state_dict.pop("t_embedder.mlp.2.bias")
+ state_dict.pop("y_embedder.embedding_table.weight")
+
+ state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
+ state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
+ state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
+ state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
+
+ state_dict.pop("final_layer.linear.weight")
+ state_dict.pop("final_layer.linear.bias")
+ state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ state_dict.pop("final_layer.adaLN_modulation.1.bias")
+
+ # DiT XL/2
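+ # out_channels is 8 rather than 4 because DiT-XL/2 predicts a learned variance alongside the noise (2 x 4 latent channels)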
+ transformer = Transformer2DModel(
+ sample_size=args.image_size // 8,
+ num_layers=28,
+ attention_head_dim=72,
+ in_channels=4,
+ out_channels=8,
+ patch_size=2,
+ attention_bias=True,
+ num_attention_heads=16,
+ activation_fn="gelu-approximate",
+ num_embeds_ada_norm=1000,
+ norm_type="ada_norm_zero",
+ norm_elementwise_affine=False,
+ )
+ transformer.load_state_dict(state_dict, strict=True)
+
+ scheduler = DDIMScheduler(
+ num_train_timesteps=1000,
+ beta_schedule="linear",
+ prediction_type="epsilon",
+ clip_sample=False,
+ )
+
+ vae = AutoencoderKL.from_pretrained(args.vae_model)
+
+ pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
+
+ if args.save:
+ pipeline.save_pretrained(args.checkpoint_path)
+
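+# Illustrative usage (a sketch; the script filename and output directory below are assumptions, not defined in this file):
+#
+#   python scripts/convert_dit_to_diffusers.py --image_size 256 --checkpoint_path dit-xl-2-256
+#
+# This downloads the matching DiT-XL/2 checkpoint, remaps its keys onto diffusers' Transformer2DModel,
+# and saves a complete DiTPipeline to `--checkpoint_path`.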
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--image_size",
+ default=256,
+ type=int,
+ required=False,
+ help="Image size of pretrained model, either 256 or 512.",
+ )
+ parser.add_argument(
+ "--vae_model",
+ default="stabilityai/sd-vae-ft-ema",
+ type=str,
+ required=False,
+ help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
+ )
+ parser.add_argument(
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
+ )
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/diffusers/scripts/convert_flux_to_diffusers.py b/diffusers/scripts/convert_flux_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a1da256d336aaefc8ad630fc17b1a587dd5d09
--- /dev/null
+++ b/diffusers/scripts/convert_flux_to_diffusers.py
@@ -0,0 +1,303 @@
+import argparse
+from contextlib import nullcontext
+
+import safetensors.torch
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+
+from diffusers import AutoencoderKL, FluxTransformer2DModel
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+"""
+# Transformer
+
+python scripts/convert_flux_to_diffusers.py \
+--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
+--filename "flux1-schnell.sft"
+--output_path "flux-schnell" \
+--transformer
+"""
+
+"""
+# VAE
+
+python scripts/convert_flux_to_diffusers.py \
+--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
+--filename "ae.sft"
+--output_path "flux-schnell" \
+--vae
+"""
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--filename", default="flux.safetensors", type=str)
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--vae", action="store_true")
+parser.add_argument("--transformer", action="store_true")
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--dtype", type=str, default="bf16")
+
+args = parser.parse_args()
+dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
+
+
+def load_original_checkpoint(args):
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+ elif args.checkpoint_path is not None:
+ ckpt_path = args.checkpoint_path
+ else:
+ raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`.")
+
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
+ return original_state_dict
+
+
+# In the original implementation of AdaLayerNormContinuous (as in SD3), the linear projection output is split
+# into (shift, scale), while diffusers splits it into (scale, shift). We therefore swap the halves of the linear
+# projection weights so that the diffusers implementation can be used.
+def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
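+# A quick sanity check on the swap (illustrative values, not taken from a real checkpoint):
+# swap_scale_shift(torch.tensor([1., 2., 3., 4.])) -> tensor([3., 4., 1., 2.]),
+# i.e. the (shift, scale) halves trade places to give diffusers' (scale, shift) ordering.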
+
+def convert_flux_transformer_checkpoint_to_diffusers(
+ original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
+):
+ converted_state_dict = {}
+
+ ## time_text_embed.timestep_embedder <- time_in
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ "time_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "time_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ "time_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "time_in.out_layer.bias"
+ )
+
+ ## time_text_embed.text_embedder <- vector_in
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
+ "vector_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
+ "vector_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
+ "vector_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
+ "vector_in.out_layer.bias"
+ )
+
+ # guidance
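+ # guidance_in only exists in guidance-distilled checkpoints such as FLUX.1-dev; schnell is timestep-distilled and has none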
+ has_guidance = any("guidance" in k for k in original_state_dict)
+ if has_guidance:
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
+ "guidance_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
+ "guidance_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
+ "guidance_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
+ "guidance_in.out_layer.bias"
+ )
+
+ # context_embedder
+ converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
+ converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")
+
+ # x_embedder
+ converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
+ converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ # norms.
+ ## norm1
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.bias"
+ )
+ ## norm1_context
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.bias"
+ )
+ # Q, K, V
+ sample_q, sample_k, sample_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
+ )
+ context_q, context_k, context_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+ )
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+ )
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ # qk_norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+ )
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.0.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.2.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.0.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.2.bias"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.bias"
+ )
+
+ # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.bias"
+ )
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
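+ # for FLUX.1 (inner_dim=3072, mlp_ratio=4.0, as passed in from main below) this splits linear1 into
+ # three 3072-row attention projections plus a 12288-row MLP input projection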
+ q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+ # qk norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.key_norm.scale"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.weight"
+ )
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.bias"
+ )
+
+ converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ )
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
+ )
+
+ return converted_state_dict
+
+
+def main(args):
+ original_ckpt = load_original_checkpoint(args)
+ has_guidance = any("guidance" in k for k in original_ckpt)
+
+ if args.transformer:
+ num_layers = 19
+ num_single_layers = 38
+ inner_dim = 3072
+ mlp_ratio = 4.0
+ converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
+ original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
+ )
+ transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ print(
+ f"Saving Flux Transformer in Diffusers format. Variant: {'guidance-distilled' if has_guidance else 'timestep-distilled'}"
+ )
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ if args.vae:
+ config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
+ vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(torch.bfloat16)
+
+ converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/diffusers/scripts/convert_gligen_to_diffusers.py b/diffusers/scripts/convert_gligen_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..83c1f928e407f9a9cc93a07a609ae2472c2a7b18
--- /dev/null
+++ b/diffusers/scripts/convert_gligen_to_diffusers.py
@@ -0,0 +1,581 @@
+import argparse
+import re
+
+import torch
+import yaml
+from transformers import (
+ CLIPProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ StableDiffusionGLIGENPipeline,
+ StableDiffusionGLIGENTextImagePipeline,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+ assign_to_checkpoint,
+ conv_attn_to_linear,
+ protected,
+ renew_attention_paths,
+ renew_resnet_paths,
+ renew_vae_attention_paths,
+ renew_vae_resnet_paths,
+ shave_segments,
+ textenc_conversion_map,
+ textenc_pattern,
+)
+
+
+def convert_open_clip_checkpoint(checkpoint):
+ checkpoint = checkpoint["text_encoder"]
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ keys = list(checkpoint.keys())
+
+ text_model_dict = {}
+
+ if "cond_stage_model.model.text_projection" in checkpoint:
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+ else:
+ d_model = 1024
+
+ for key in keys:
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
+ continue
+ if key in textenc_conversion_map:
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
+ # if key.startswith("cond_stage_model.model.transformer."):
+ new_key = key[len("transformer.") :]
+ if new_key.endswith(".in_proj_weight"):
+ new_key = new_key[: -len(".in_proj_weight")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
+ elif new_key.endswith(".in_proj_bias"):
+ new_key = new_key[: -len(".in_proj_bias")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
+ else:
+ if key != "transformer.text_model.embeddings.position_ids":
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+
+ text_model_dict[new_key] = checkpoint[key]
+
+ if key == "transformer.text_model.embeddings.token_embedding.weight":
+ text_model_dict["text_model.embeddings.token_embedding.weight"] = checkpoint[key]
+
+ text_model_dict.pop("text_model.embeddings.transformer.text_model.embeddings.token_embedding.weight")
+
+ text_model.load_state_dict(text_model_dict)
+
+ return text_model
+
+
+def convert_gligen_vae_checkpoint(checkpoint, config):
+ checkpoint = checkpoint["autoencoder"]
+ vae_state_dict = {}
+ vae_key = "first_stage_model."
+ keys = list(checkpoint.keys())
+ for key in keys:
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for key in new_checkpoint.keys():
+ if "encoder.mid_block.attentions.0" in key or "decoder.mid_block.attentions.0" in key:
+ if "query" in key:
+ new_checkpoint[key.replace("query", "to_q")] = new_checkpoint.pop(key)
+ if "value" in key:
+ new_checkpoint[key.replace("value", "to_v")] = new_checkpoint.pop(key)
+ if "key" in key:
+ new_checkpoint[key.replace("key", "to_k")] = new_checkpoint.pop(key)
+ if "proj_attn" in key:
+ new_checkpoint[key.replace("proj_attn", "to_out.0")] = new_checkpoint.pop(key)
+
+ return new_checkpoint
+
+
+def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
+ unet_state_dict = {}
+ checkpoint = checkpoint["model"]
+ keys = list(checkpoint.keys())
+
+ unet_key = "model.diffusion_model."
+
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+ print(f"Checkpoint {path} has bot EMA and non-EMA weights.")
+ print(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ print(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+ for key in keys:
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnets)
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ for key in keys:
+ if "position_net" in key:
+ new_checkpoint[key] = unet_state_dict[key]
+
+ return new_checkpoint
+
+
+def create_vae_config(original_config, image_size: int):
+ vae_params = original_config["autoencoder"]["params"]["ddconfig"]
+ _ = original_config["autoencoder"]["params"]["embed_dim"]
+
+ block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+ config = {
+ "sample_size": image_size,
+ "in_channels": vae_params["in_channels"],
+ "out_channels": vae_params["out_ch"],
+ "down_block_types": tuple(down_block_types),
+ "up_block_types": tuple(up_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "latent_channels": vae_params["z_channels"],
+ "layers_per_block": vae_params["num_res_blocks"],
+ }
+
+ return config
+
+
+def create_unet_config(original_config, image_size: int, attention_type):
+ unet_params = original_config["model"]["params"]
+ vae_params = original_config["autoencoder"]["params"]["ddconfig"]
+
+ block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
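+ # e.g. with the usual 4-entry ch_mult this gives a scale factor of 8, so sample_size below is image_size // 8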
+ vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
+
+ head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
+ use_linear_projection = (
+ unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
+ )
+ if use_linear_projection:
+ if head_dim is None:
+ head_dim = [5, 10, 20, 20]
+
+ config = {
+ "sample_size": image_size // vae_scale_factor,
+ "in_channels": unet_params["in_channels"],
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": unet_params["num_res_blocks"],
+ "cross_attention_dim": unet_params["context_dim"],
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "attention_type": attention_type,
+ }
+
+ return config
+
+
+def convert_gligen_to_diffusers(
+ checkpoint_path: str,
+ original_config_file: str,
+ attention_type: str,
+ image_size: int = 512,
+ extract_ema: bool = False,
+ num_in_channels: int = None,
+ device: str = None,
+):
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+
+ if "global_step" in checkpoint:
+ checkpoint["global_step"]
+ else:
+ print("global_step key not found in model")
+
+ with open(original_config_file) as f:
+ original_config = yaml.safe_load(f)
+
+ if num_in_channels is not None:
+ original_config["model"]["params"]["in_channels"] = num_in_channels
+
+ num_train_timesteps = original_config["diffusion"]["params"]["timesteps"]
+ beta_start = original_config["diffusion"]["params"]["linear_start"]
+ beta_end = original_config["diffusion"]["params"]["linear_end"]
+
+ scheduler = DDIMScheduler(
+ beta_end=beta_end,
+ beta_schedule="scaled_linear",
+ beta_start=beta_start,
+ num_train_timesteps=num_train_timesteps,
+ steps_offset=1,
+ clip_sample=False,
+ set_alpha_to_one=False,
+ prediction_type="epsilon",
+ )
+
+ # Convert the UNet2DConditionModel
+ unet_config = create_unet_config(original_config, image_size, attention_type)
+ unet = UNet2DConditionModel(**unet_config)
+
+ converted_unet_checkpoint = convert_gligen_unet_checkpoint(
+ checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+ )
+
+ unet.load_state_dict(converted_unet_checkpoint)
+
+ # Convert the VAE model
+ vae_config = create_vae_config(original_config, image_size)
+ converted_vae_checkpoint = convert_gligen_vae_checkpoint(checkpoint, vae_config)
+
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_checkpoint)
+
+ # Convert the text model
+ text_encoder = convert_open_clip_checkpoint(checkpoint)
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+ if attention_type == "gated-text-image":
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+ pipe = StableDiffusionGLIGENTextImagePipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ processor=processor,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+ elif attention_type == "gated":
+ pipe = StableDiffusionGLIGENPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+
+ return pipe
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--original_config_file",
+ default=None,
+ type=str,
+ required=True,
+ help="The YAML config file corresponding to the gligen architecture.",
+ )
+ parser.add_argument(
+ "--num_in_channels",
+ default=None,
+ type=int,
+ help="The number of input channels. If `None` number of input channels will be automatically inferred.",
+ )
+ parser.add_argument(
+ "--extract_ema",
+ action="store_true",
+ help=(
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
+ ),
+ )
+ parser.add_argument(
+ "--attention_type",
+ default=None,
+ type=str,
+ required=True,
+ help="Type of attention ex: gated or gated-text-image",
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--device", type=str, help="Device to use.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+
+ args = parser.parse_args()
+
+ pipe = convert_gligen_to_diffusers(
+ checkpoint_path=args.checkpoint_path,
+ original_config_file=args.original_config_file,
+ attention_type=args.attention_type,
+ extract_ema=args.extract_ema,
+ num_in_channels=args.num_in_channels,
+ device=args.device,
+ )
+
+ if args.half:
+ pipe.to(dtype=torch.float16)
+
+ pipe.save_pretrained(args.dump_path)
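+# Illustrative invocation (a sketch; the checkpoint, config file, and output paths are assumptions):
+#   python scripts/convert_gligen_to_diffusers.py --checkpoint_path gligen.ckpt \
+#     --original_config_file gligen.yaml --attention_type gated --dump_path gligen-diffusers --half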
diff --git a/diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py b/diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c838369089077c93301f42acbc5e17ee5d64bf9
--- /dev/null
+++ b/diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py
@@ -0,0 +1,241 @@
+import argparse
+
+import torch
+
+from diffusers import HunyuanDiT2DControlNetModel
+
+
+def main(args):
+ state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
+
+ if args.load_key != "none":
+ try:
+ state_dict = state_dict[args.load_key]
+ except KeyError:
+ raise KeyError(
+ f"{args.load_key} not found in the checkpoint. "
+ f"Please load from the following keys: {state_dict.keys()}"
+ )
+ device = "cuda"
+
+ model_config = HunyuanDiT2DControlNetModel.load_config(
+ "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
+ )
+ model_config[
+ "use_style_cond_and_image_meta_size"
+ ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+ print(model_config)
+
+ for key in state_dict:
+ print("local:", key)
+
+ model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)
+
+ for key in model.state_dict():
+ print("diffusers:", key)
+
+ num_layers = 19
+ for i in range(num_layers):
+ # attn1
+ # Wqkv -> to_q, to_k, to_v
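+ # chunking along dim=0 relies on the fused Wqkv weight rows being laid out in [q; k; v] order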
+ q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
+ state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
+ state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
+ state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
+
+ # attn2
+ # kv_proj -> to_k, to_v
+ k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
+ state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
+
+ # q_proj -> to_q
+ state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
+
+ # switch norm 2 and norm 3
+ norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
+ norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
+ state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
+ state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
+ state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
+ state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
+
+ # norm1 -> norm1.norm
+ # default_modulation.1 -> norm1.linear
+ state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
+ state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
+ state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
+ state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
+ state_dict.pop(f"blocks.{i}.norm1.weight")
+ state_dict.pop(f"blocks.{i}.norm1.bias")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
+
+ # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
+ state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
+ state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
+ state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
+
+ # after_proj_list -> controlnet_blocks
+ state_dict[f"controlnet_blocks.{i}.weight"] = state_dict[f"after_proj_list.{i}.weight"]
+ state_dict[f"controlnet_blocks.{i}.bias"] = state_dict[f"after_proj_list.{i}.bias"]
+ state_dict.pop(f"after_proj_list.{i}.weight")
+ state_dict.pop(f"after_proj_list.{i}.bias")
+
+ # before_proj -> input_block
+ state_dict["input_block.weight"] = state_dict["before_proj.weight"]
+ state_dict["input_block.bias"] = state_dict["before_proj.bias"]
+ state_dict.pop("before_proj.weight")
+ state_dict.pop("before_proj.bias")
+
+ # pooler -> time_extra_emb
+ state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
+ state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
+ state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
+ state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
+ state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
+ state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
+ state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
+ state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
+ state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
+ state_dict.pop("pooler.k_proj.weight")
+ state_dict.pop("pooler.k_proj.bias")
+ state_dict.pop("pooler.q_proj.weight")
+ state_dict.pop("pooler.q_proj.bias")
+ state_dict.pop("pooler.v_proj.weight")
+ state_dict.pop("pooler.v_proj.bias")
+ state_dict.pop("pooler.c_proj.weight")
+ state_dict.pop("pooler.c_proj.bias")
+ state_dict.pop("pooler.positional_embedding")
+
+ # t_embedder -> time_extra_emb.timestep_embedder (`TimestepEmbedding`)
+ state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
+
+ state_dict.pop("t_embedder.mlp.0.bias")
+ state_dict.pop("t_embedder.mlp.0.weight")
+ state_dict.pop("t_embedder.mlp.2.bias")
+ state_dict.pop("t_embedder.mlp.2.weight")
+
+ # x_embedder -> pos_embed (`PatchEmbed`)
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+ state_dict.pop("x_embedder.proj.weight")
+ state_dict.pop("x_embedder.proj.bias")
+
+ # mlp_t5 -> text_embedder
+ state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
+ state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
+ state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
+ state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
+ state_dict.pop("mlp_t5.0.bias")
+ state_dict.pop("mlp_t5.0.weight")
+ state_dict.pop("mlp_t5.2.bias")
+ state_dict.pop("mlp_t5.2.weight")
+
+ # extra_embedder -> time_extra_emb.extra_embedder
+ state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
+ state_dict.pop("extra_embedder.0.bias")
+ state_dict.pop("extra_embedder.0.weight")
+ state_dict.pop("extra_embedder.2.bias")
+ state_dict.pop("extra_embedder.2.weight")
+
+ # style_embedder
+ if model_config["use_style_cond_and_image_meta_size"]:
+ print(state_dict["style_embedder.weight"])
+ print(state_dict["style_embedder.weight"].shape)
+ state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
+ state_dict.pop("style_embedder.weight")
+
+ model.load_state_dict(state_dict)
+
+ if args.save:
+ model.save_pretrained(args.output_checkpoint_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
+ )
+ parser.add_argument(
+ "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
+ )
+ parser.add_argument(
+ "--output_checkpoint_path",
+ default=None,
+ type=str,
+ required=False,
+ help="Path to the output converted diffusers pipeline.",
+ )
+ parser.add_argument(
+ "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
+ )
+ parser.add_argument(
+ "--use_style_cond_and_image_meta_size",
+ type=bool,
+ default=False,
+ help="version <= v1.1: True; version >= v1.2: False",
+ )
+
+ args = parser.parse_args()
+ main(args)
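+# Illustrative invocation (a sketch; the .pt filename and output directory are assumptions):
+#   python scripts/convert_hunyuandit_controlnet_to_diffusers.py --pt_checkpoint_path controlnet.pt \
+#     --load_key none --output_checkpoint_path hunyuandit-controlnet-diffusers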
diff --git a/diffusers/scripts/convert_hunyuandit_to_diffusers.py b/diffusers/scripts/convert_hunyuandit_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..da3af8333ee3e34f4b8af8b3f58e38e8dec5925c
--- /dev/null
+++ b/diffusers/scripts/convert_hunyuandit_to_diffusers.py
@@ -0,0 +1,267 @@
+import argparse
+
+import torch
+
+from diffusers import HunyuanDiT2DModel
+
+
+def main(args):
+ state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
+
+ if args.load_key != "none":
+ try:
+ state_dict = state_dict[args.load_key]
+ except KeyError:
+ raise KeyError(
+ f"{args.load_key} not found in the checkpoint."
+ f"Please load from the following keys:{state_dict.keys()}"
+ )
+
+ device = "cuda"
+ model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
+ model_config[
+ "use_style_cond_and_image_meta_size"
+ ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+
+ # input_size -> sample_size, text_dim -> cross_attention_dim
+ for key in state_dict:
+ print("local:", key)
+
+ model = HunyuanDiT2DModel.from_config(model_config).to(device)
+
+ for key in model.state_dict():
+ print("diffusers:", key)
+
+ num_layers = 40
+ for i in range(num_layers):
+ # attn1
+ # Wqkv -> to_q, to_k, to_v
+ q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
+ state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
+ state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
+ state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
+
+ # attn2
+ # kv_proj -> to_k, to_v
+ k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
+ state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
+
+ # q_proj -> to_q
+ state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
+
+ # switch norm 2 and norm 3
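+ # the original block uses norm2/norm3 in roles opposite to diffusers' naming, so their weights are exchanged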
+ norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
+ norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
+ state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
+ state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
+ state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
+ state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
+
+ # norm1 -> norm1.norm
+ # default_modulation.1 -> norm1.linear
+ state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
+ state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
+ state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
+ state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
+ state_dict.pop(f"blocks.{i}.norm1.weight")
+ state_dict.pop(f"blocks.{i}.norm1.bias")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
+
+ # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
+ state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
+ state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
+ state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
+
+ # pooler -> time_extra_emb
+ state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
+ state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
+ state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
+ state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
+ state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
+ state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
+ state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
+ state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
+ state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
+ state_dict.pop("pooler.k_proj.weight")
+ state_dict.pop("pooler.k_proj.bias")
+ state_dict.pop("pooler.q_proj.weight")
+ state_dict.pop("pooler.q_proj.bias")
+ state_dict.pop("pooler.v_proj.weight")
+ state_dict.pop("pooler.v_proj.bias")
+ state_dict.pop("pooler.c_proj.weight")
+ state_dict.pop("pooler.c_proj.bias")
+ state_dict.pop("pooler.positional_embedding")
+
+    # t_embedder -> time_extra_emb.timestep_embedder (`TimestepEmbedding`)
+ state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
+
+ state_dict.pop("t_embedder.mlp.0.bias")
+ state_dict.pop("t_embedder.mlp.0.weight")
+ state_dict.pop("t_embedder.mlp.2.bias")
+ state_dict.pop("t_embedder.mlp.2.weight")
+
+    # x_embedder -> pos_embed (`PatchEmbed`)
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+ state_dict.pop("x_embedder.proj.weight")
+ state_dict.pop("x_embedder.proj.bias")
+
+ # mlp_t5 -> text_embedder
+ state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
+ state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
+ state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
+ state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
+ state_dict.pop("mlp_t5.0.bias")
+ state_dict.pop("mlp_t5.0.weight")
+ state_dict.pop("mlp_t5.2.bias")
+ state_dict.pop("mlp_t5.2.weight")
+
+    # extra_embedder -> time_extra_emb.extra_embedder
+ state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
+ state_dict.pop("extra_embedder.0.bias")
+ state_dict.pop("extra_embedder.0.weight")
+ state_dict.pop("extra_embedder.2.bias")
+ state_dict.pop("extra_embedder.2.weight")
+
+    # final_layer.adaLN_modulation.1 -> norm_out.linear
+ def swap_scale_shift(weight):
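+        # The original checkpoint stores the final adaLN modulation as [shift, scale];
+        # the diffusers norm_out.linear layout expects [scale, shift], so swap the two halves.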
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+ state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"])
+ state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"])
+ state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ state_dict.pop("final_layer.adaLN_modulation.1.bias")
+
+ # final_linear -> proj_out
+ state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"]
+ state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"]
+ state_dict.pop("final_layer.linear.weight")
+ state_dict.pop("final_layer.linear.bias")
+
+ # style_embedder
+ if model_config["use_style_cond_and_image_meta_size"]:
+ print(state_dict["style_embedder.weight"])
+ print(state_dict["style_embedder.weight"].shape)
+ state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
+ state_dict.pop("style_embedder.weight")
+
+ model.load_state_dict(state_dict)
+
+ from diffusers import HunyuanDiTPipeline
+
+ if args.use_style_cond_and_image_meta_size:
+ pipe = HunyuanDiTPipeline.from_pretrained(
+ "Tencent-Hunyuan/HunyuanDiT-Diffusers", transformer=model, torch_dtype=torch.float32
+ )
+ else:
+ pipe = HunyuanDiTPipeline.from_pretrained(
+ "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", transformer=model, torch_dtype=torch.float32
+ )
+ pipe.to("cuda")
+ pipe.to(dtype=torch.float32)
+
+ if args.save:
+ pipe.save_pretrained(args.output_checkpoint_path)
+
+    # NOTE: HunyuanDiT supports both Chinese and English inputs
+ prompt = "一个宇航员在骑马"
+ # prompt = "An astronaut riding a horse"
+ generator = torch.Generator(device="cuda").manual_seed(0)
+ image = pipe(
+ height=1024, width=1024, prompt=prompt, generator=generator, num_inference_steps=25, guidance_scale=5.0
+ ).images[0]
+
+ image.save("img.png")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+    # Note: argparse's `type=bool` treats any non-empty string as True, so parse booleans explicitly.
+    parser.add_argument(
+        "--save", default=True, type=lambda x: str(x).lower() not in ("false", "0", "no"),
+        required=False, help="Whether to save the converted pipeline or not.",
+    )
+ parser.add_argument(
+ "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
+ )
+ parser.add_argument(
+ "--output_checkpoint_path",
+ default=None,
+ type=str,
+ required=False,
+ help="Path to the output converted diffusers pipeline.",
+ )
+ parser.add_argument(
+ "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
+ )
+    parser.add_argument(
+        "--use_style_cond_and_image_meta_size",
+        action="store_true",
+        help="Pass this flag for HunyuanDiT version <= v1.1; omit it for version >= v1.2.",
+    )
+
+ args = parser.parse_args()
+ main(args)
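+
+# Example invocation (the script filename and all paths below are illustrative placeholders):
+#   python convert_hunyuandit_to_diffusers.py \
+#       --pt_checkpoint_path /path/to/hunyuandit_checkpoint.pt \
+#       --output_checkpoint_path /path/to/output_dir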
diff --git a/diffusers/scripts/convert_i2vgen_to_diffusers.py b/diffusers/scripts/convert_i2vgen_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9e3ff2cd35cb3334f7228a1fd5cc4e713b50a6c
--- /dev/null
+++ b/diffusers/scripts/convert_i2vgen_to_diffusers.py
@@ -0,0 +1,510 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Conversion script for the LDM checkpoints."""
+
+import argparse
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers import DDIMScheduler, I2VGenXLPipeline, I2VGenXLUNet, StableDiffusionPipeline
+
+
+CLIP_ID = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+
+
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ """
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
+ attention layers, and takes into account additional replacements that may arise.
+
+ Assigns the weights to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ # Splits the attention layers into three variables.
+ if attention_paths_to_split is not None:
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+ for path in paths:
+ new_path = path["new"]
+
+ # These have already been assigned
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ weight = old_checkpoint[path["old"]]
+ names = ["proj_attn.weight"]
+ names_2 = ["proj_out.weight", "proj_in.weight"]
+ if any(k in new_path for k in names):
+ checkpoint[new_path] = weight[:, :, 0]
+ elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path:
+ checkpoint[new_path] = weight[:, :, 0]
+ else:
+ checkpoint[new_path] = weight
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ mapping.append({"old": old_item, "new": old_item})
+
+ return mapping
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
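+        # NOTE: "temopral_conv" (sic) matches the misspelled module name used in the original checkpoint keys.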
+ if "temopral_conv" not in old_item:
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ keys = list(checkpoint.keys())
+
+ unet_key = "model.diffusion_model."
+
+    # at least 100 parameters have to start with `model_ema` for the checkpoint to be considered EMA
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+ print(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ print(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+
+ for key in keys:
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ additional_embedding_substrings = [
+ "local_image_concat",
+ "context_embedding",
+ "local_image_embedding",
+ "fps_embedding",
+ ]
+ for k in unet_state_dict:
+ if any(substring in k for substring in additional_embedding_substrings):
+ diffusers_key = k.replace("local_image_concat", "image_latents_proj_in").replace(
+ "local_image_embedding", "image_latents_context_embedding"
+ )
+ new_checkpoint[diffusers_key] = unet_state_dict[k]
+
+ # temporal encoder.
+ new_checkpoint["image_latents_temporal_encoder.norm1.weight"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.0.norm.weight"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.norm1.bias"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.0.norm.bias"
+ ]
+
+ # attention
+ qkv = unet_state_dict["local_temporal_encoder.layers.0.0.fn.to_qkv.weight"]
+ q, k, v = torch.chunk(qkv, 3, dim=0)
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_q.weight"] = q
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_k.weight"] = k
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_v.weight"] = v
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.weight"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.0.fn.to_out.0.weight"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.bias"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.0.fn.to_out.0.bias"
+ ]
+
+ # feedforward
+ new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.weight"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.1.net.0.0.weight"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.bias"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.1.net.0.0.bias"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.ff.net.2.weight"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.1.net.2.weight"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.ff.net.2.bias"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.1.net.2.bias"
+ ]
+
+ if "class_embed_type" in config:
+ if config["class_embed_type"] is None:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")]
+ paths = renew_attention_paths(first_temp_attention)
+ meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"}
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
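+    # input_blocks.0 holds the stem conv handled above; the remaining input blocks come in groups
+    # of (layers_per_block + 1) entries per stage, hence the // and % arithmetic below.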
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+ temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key]
+
+ if f"input_blocks.{i}.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ temporal_convs = [key for key in resnets if "temopral_conv" in key]
+ paths = renew_temp_conv_paths(temporal_convs)
+ meta_path = {
+ "old": f"input_blocks.{i}.0.temopral_conv",
+ "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(temp_attentions):
+ paths = renew_attention_paths(temp_attentions)
+ meta_path = {
+ "old": f"input_blocks.{i}.2",
+ "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key]
+ attentions = middle_blocks[1]
+ temp_attentions = middle_blocks[2]
+ resnet_1 = middle_blocks[3]
+ temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
+ assign_to_checkpoint(
+ resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
+ )
+
+ temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0)
+ meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"}
+ assign_to_checkpoint(
+ temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
+ )
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"}
+ assign_to_checkpoint(
+ resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
+ )
+
+ temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1)
+ meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"}
+ assign_to_checkpoint(
+ temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
+ )
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ temp_attentions_paths = renew_attention_paths(temp_attentions)
+ meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"}
+ assign_to_checkpoint(
+ temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+ temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnets)
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ temporal_convs = [key for key in resnets if "temopral_conv" in key]
+ paths = renew_temp_conv_paths(temporal_convs)
+ meta_path = {
+ "old": f"output_blocks.{i}.0.temopral_conv",
+ "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(temp_attentions):
+ paths = renew_attention_paths(temp_attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.2",
+ "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l]
+ for path in temopral_conv_paths:
+ pruned_path = path.split("temopral_conv.")[-1]
+ old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path])
+ new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path])
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ return new_checkpoint
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--unet_checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--push_to_hub", action="store_true")
+ args = parser.parse_args()
+
+ # UNet
+ unet_checkpoint = torch.load(args.unet_checkpoint_path, map_location="cpu")
+ unet_checkpoint = unet_checkpoint["state_dict"]
+ unet = I2VGenXLUNet(sample_size=32)
+
+ converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config)
+
+ diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys())
+ diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys())
+
+    assert len(diff_0) == len(diff_1) == 0, f"Converted weights don't match. Missing: {diff_0}; unexpected: {diff_1}"
+
+ unet.load_state_dict(converted_ckpt, strict=True)
+
+ # vae
+ temp_pipe = StableDiffusionPipeline.from_single_file(
+ "https://huggingface.co/ali-vilab/i2vgen-xl/blob/main/models/v2-1_512-ema-pruned.ckpt"
+ )
+ vae = temp_pipe.vae
+ del temp_pipe
+
+ # text encoder and tokenizer
+ text_encoder = CLIPTextModel.from_pretrained(CLIP_ID)
+ tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
+
+ # image encoder and feature extractor
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(CLIP_ID)
+ feature_extractor = CLIPImageProcessor.from_pretrained(CLIP_ID)
+
+ # scheduler
+ # https://github.com/ali-vilab/i2vgen-xl/blob/main/configs/i2vgen_xl_train.yaml
+ scheduler = DDIMScheduler(
+ beta_schedule="squaredcos_cap_v2",
+ rescale_betas_zero_snr=True,
+ set_alpha_to_one=True,
+ clip_sample=False,
+ steps_offset=1,
+ timestep_spacing="leading",
+ prediction_type="v_prediction",
+ )
+
+ # final
+ pipeline = I2VGenXLPipeline(
+ unet=unet,
+ vae=vae,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ )
+
+ pipeline.save_pretrained(args.dump_path, push_to_hub=args.push_to_hub)
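+
+    # Example invocation (paths are placeholders):
+    #   python diffusers/scripts/convert_i2vgen_to_diffusers.py \
+    #       --unet_checkpoint_path /path/to/i2vgen_xl_unet.pth --dump_path ./i2vgen-xl-diffusers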
diff --git a/diffusers/scripts/convert_if.py b/diffusers/scripts/convert_if.py
new file mode 100644
index 0000000000000000000000000000000000000000..85c739ca92f04bd2e1fefdaca3ff15b59f75c66e
--- /dev/null
+++ b/diffusers/scripts/convert_if.py
@@ -0,0 +1,1250 @@
+import argparse
+import inspect
+import os
+
+import numpy as np
+import torch
+import yaml
+from torch.nn import functional as F
+from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer
+
+from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel
+from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dump_path", required=False, default=None, type=str)
+
+ parser.add_argument("--dump_path_stage_2", required=False, default=None, type=str)
+
+ parser.add_argument("--dump_path_stage_3", required=False, default=None, type=str)
+
+ parser.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file")
+
+ parser.add_argument(
+ "--unet_checkpoint_path", required=False, default=None, type=str, help="Path to unet checkpoint file"
+ )
+
+ parser.add_argument(
+ "--unet_checkpoint_path_stage_2",
+ required=False,
+ default=None,
+ type=str,
+ help="Path to stage 2 unet checkpoint file",
+ )
+
+ parser.add_argument(
+ "--unet_checkpoint_path_stage_3",
+ required=False,
+ default=None,
+ type=str,
+ help="Path to stage 3 unet checkpoint file",
+ )
+
+ parser.add_argument("--p_head_path", type=str, required=True)
+
+ parser.add_argument("--w_head_path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ return args
+
+
+def main(args):
+ tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
+ text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
+
+ feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path)
+
+ if args.unet_config is not None and args.unet_checkpoint_path is not None and args.dump_path is not None:
+ convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args)
+
+ if args.unet_checkpoint_path_stage_2 is not None and args.dump_path_stage_2 is not None:
+ convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=2)
+
+ if args.unet_checkpoint_path_stage_3 is not None and args.dump_path_stage_3 is not None:
+ convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=3)
+
+
+def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args):
+ unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path)
+
+ scheduler = DDPMScheduler(
+ variance_type="learned_range",
+ beta_schedule="squaredcos_cap_v2",
+ prediction_type="epsilon",
+ thresholding=True,
+ dynamic_thresholding_ratio=0.95,
+ sample_max_value=1.5,
+ )
+
+ pipe = IFPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=True,
+ )
+
+ pipe.save_pretrained(args.dump_path)
+
+
+def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage):
+ if stage == 2:
+ unet_checkpoint_path = args.unet_checkpoint_path_stage_2
+ sample_size = None
+ dump_path = args.dump_path_stage_2
+ elif stage == 3:
+ unet_checkpoint_path = args.unet_checkpoint_path_stage_3
+ sample_size = 1024
+ dump_path = args.dump_path_stage_3
+ else:
+        raise ValueError(f"Unexpected stage: {stage}")
+
+    unet = get_super_res_unet(unet_checkpoint_path, check_param_count=False, sample_size=sample_size)
+
+ image_noising_scheduler = DDPMScheduler(
+ beta_schedule="squaredcos_cap_v2",
+ )
+
+ scheduler = DDPMScheduler(
+ variance_type="learned_range",
+ beta_schedule="squaredcos_cap_v2",
+ prediction_type="epsilon",
+ thresholding=True,
+ dynamic_thresholding_ratio=0.95,
+ sample_max_value=1.0,
+ )
+
+ pipe = IFSuperResolutionPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ image_noising_scheduler=image_noising_scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=True,
+ )
+
+ pipe.save_pretrained(dump_path)
+
+
+def get_stage_1_unet(unet_config, unet_checkpoint_path):
+    # `unet_config` is a path to a YAML file, so open it before parsing.
+    with open(unet_config) as f:
+        original_unet_config = yaml.safe_load(f)
+    original_unet_config = original_unet_config["params"]
+
+ unet_diffusers_config = create_unet_diffusers_config(original_unet_config)
+
+ unet = UNet2DConditionModel(**unet_diffusers_config)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ unet_checkpoint = torch.load(unet_checkpoint_path, map_location=device)
+
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+ unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
+ )
+
+ unet.load_state_dict(converted_unet_checkpoint)
+
+ return unet
+
+
+def convert_safety_checker(p_head_path, w_head_path):
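+    # Assemble the IFSafetyChecker state dict: the standalone p/w classification heads come from
+    # .npz files, and the vision backbone is copied from a pretrained CLIP vision model.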
+ state_dict = {}
+
+ # p head
+
+ p_head = np.load(p_head_path)
+
+ p_head_weights = p_head["weights"]
+ p_head_weights = torch.from_numpy(p_head_weights)
+ p_head_weights = p_head_weights.unsqueeze(0)
+
+ p_head_biases = p_head["biases"]
+ p_head_biases = torch.from_numpy(p_head_biases)
+ p_head_biases = p_head_biases.unsqueeze(0)
+
+ state_dict["p_head.weight"] = p_head_weights
+ state_dict["p_head.bias"] = p_head_biases
+
+ # w head
+
+ w_head = np.load(w_head_path)
+
+ w_head_weights = w_head["weights"]
+ w_head_weights = torch.from_numpy(w_head_weights)
+ w_head_weights = w_head_weights.unsqueeze(0)
+
+ w_head_biases = w_head["biases"]
+ w_head_biases = torch.from_numpy(w_head_biases)
+ w_head_biases = w_head_biases.unsqueeze(0)
+
+ state_dict["w_head.weight"] = w_head_weights
+ state_dict["w_head.bias"] = w_head_biases
+
+ # vision model
+
+ vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+ vision_model_state_dict = vision_model.state_dict()
+
+ for key, value in vision_model_state_dict.items():
+ key = f"vision_model.{key}"
+ state_dict[key] = value
+
+ # full model
+
+ config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
+ safety_checker = IFSafetyChecker(config)
+
+ safety_checker.load_state_dict(state_dict)
+
+ return safety_checker
+
+
+def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
+ attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
+ attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
+
+ channel_mult = parse_list(original_unet_config["channel_mult"])
+ block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
+
+ down_block_types = []
+ resolution = 1
+
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnDownBlock2D"
+ elif original_unet_config["resblock_updown"]:
+ block_type = "ResnetDownsampleBlock2D"
+ else:
+ block_type = "DownBlock2D"
+
+ down_block_types.append(block_type)
+
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnUpBlock2D"
+ elif original_unet_config["resblock_updown"]:
+ block_type = "ResnetUpsampleBlock2D"
+ else:
+ block_type = "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ head_dim = original_unet_config["num_head_channels"]
+
+ use_linear_projection = (
+ original_unet_config["use_linear_in_transformer"]
+ if "use_linear_in_transformer" in original_unet_config
+ else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim = [5, 10, 20, 20]
+
+ projection_class_embeddings_input_dim = None
+
+ if class_embed_type is None:
+ if "num_classes" in original_unet_config:
+ if original_unet_config["num_classes"] == "sequential":
+ class_embed_type = "projection"
+ assert "adm_in_channels" in original_unet_config
+ projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
+ else:
+ raise NotImplementedError(
+ f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
+ )
+
+ config = {
+ "sample_size": original_unet_config["image_size"],
+ "in_channels": original_unet_config["in_channels"],
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": original_unet_config["num_res_blocks"],
+ "cross_attention_dim": original_unet_config["encoder_channels"],
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "out_channels": original_unet_config["out_channels"],
+ "up_block_types": tuple(up_block_types),
+ "upcast_attention": False, # TODO: guessing
+ "cross_attention_norm": "group_norm",
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
+ "addition_embed_type": "text",
+ "act_fn": "gelu",
+ }
+
+ if original_unet_config["use_scale_shift_norm"]:
+ config["resnet_time_scale_shift"] = "scale_shift"
+
+ if "encoder_dim" in original_unet_config:
+ config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
+
+ return config
+
+
+def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] in [None, "identity"]:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+
+ # TODO need better check than i in [4, 8, 12, 16]
+ block_type = config["down_block_types"][block_id]
+ if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [
+ 4,
+ 8,
+ 12,
+ 16,
+ ]:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
+ else:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ old_path = f"input_blocks.{i}.1"
+ new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": old_path, "new": new_path}
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ unet_state_dict,
+ additional_replacements=[meta_path],
+ config=config,
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ old_path = "middle_block.1"
+ new_path = "mid_block.attentions.0"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ # len(output_block_list) == 1 -> resnet
+ # len(output_block_list) == 2 -> resnet, attention
+ # len(output_block_list) == 3 -> resnet, attention, upscale resnet
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ old_path = f"output_blocks.{i}.1"
+ new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": old_path,
+ "new": new_path,
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(output_block_list) == 3:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ if "encoder_proj.weight" in unet_state_dict:
+ new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight")
+ new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias")
+
+ if "encoder_pooling.0.weight" in unet_state_dict:
+ new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight")
+ new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias")
+
+ new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop(
+ "encoder_pooling.1.positional_embedding"
+ )
+ new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight")
+ new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias")
+ new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight")
+ new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias")
+ new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight")
+ new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias")
+
+ new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight")
+ new_checkpoint["add_embedding.proj.bias"] = unet_state_dict.pop("encoder_pooling.2.bias")
+
+ new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight")
+ new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias")
+
+ return new_checkpoint
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ if "qkv" in new_item:
+ continue
+
+ if "encoder_kv" in new_item:
+ continue
+
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+ new_item = new_item.replace("norm_encoder.weight", "norm_cross.weight")
+ new_item = new_item.replace("norm_encoder.bias", "norm_cross.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config):
+ qkv_weight = unet_state_dict.pop(f"{old_path}.qkv.weight")
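+    # The original checkpoint stores this projection as a 1x1 conv, so drop the trailing dim to get a linear weight.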
+ qkv_weight = qkv_weight[:, :, 0]
+
+ qkv_bias = unet_state_dict.pop(f"{old_path}.qkv.bias")
+
+ is_cross_attn_only = "only_cross_attention" in config and config["only_cross_attention"]
+
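+    # With only_cross_attention, the fused "qkv" projection holds just the query; otherwise it holds q, k and v.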
+ split = 1 if is_cross_attn_only else 3
+
+ weights, bias = split_attentions(
+ weight=qkv_weight,
+ bias=qkv_bias,
+ split=split,
+ chunk_size=config["attention_head_dim"],
+ )
+
+ if is_cross_attn_only:
+ query_weight, q_bias = weights, bias
+ new_checkpoint[f"{new_path}.to_q.weight"] = query_weight[0]
+ new_checkpoint[f"{new_path}.to_q.bias"] = q_bias[0]
+ else:
+ [query_weight, key_weight, value_weight], [q_bias, k_bias, v_bias] = weights, bias
+ new_checkpoint[f"{new_path}.to_q.weight"] = query_weight
+ new_checkpoint[f"{new_path}.to_q.bias"] = q_bias
+ new_checkpoint[f"{new_path}.to_k.weight"] = key_weight
+ new_checkpoint[f"{new_path}.to_k.bias"] = k_bias
+ new_checkpoint[f"{new_path}.to_v.weight"] = value_weight
+ new_checkpoint[f"{new_path}.to_v.bias"] = v_bias
+
+ encoder_kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight")
+ encoder_kv_weight = encoder_kv_weight[:, :, 0]
+
+ encoder_kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias")
+
+ [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
+ weight=encoder_kv_weight,
+ bias=encoder_kv_bias,
+ split=2,
+ chunk_size=config["attention_head_dim"],
+ )
+
+ new_checkpoint[f"{new_path}.add_k_proj.weight"] = encoder_k_weight
+ new_checkpoint[f"{new_path}.add_k_proj.bias"] = encoder_k_bias
+ new_checkpoint[f"{new_path}.add_v_proj.weight"] = encoder_v_weight
+ new_checkpoint[f"{new_path}.add_v_proj.bias"] = encoder_v_bias
+
+
+def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None):
+ """
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
+ attention layers, and takes into account additional replacements that may arise.
+
+ Assigns the weights to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ for path in paths:
+ new_path = path["new"]
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ if "proj_attn.weight" in new_path or "to_out.0.weight" in new_path:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
+def split_attentions(*, weight, bias, split, chunk_size):
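+    # Walk the rows of the fused projection in blocks of `chunk_size` and deal the blocks out
+    # round-robin across `split` output tensors, concatenating the per-head blocks as we go.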
+ weights = [None] * split
+ biases = [None] * split
+
+ weights_biases_idx = 0
+
+ for starting_row_index in range(0, weight.shape[0], chunk_size):
+ row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
+
+ weight_rows = weight[row_indices, :]
+ bias_rows = bias[row_indices]
+
+ if weights[weights_biases_idx] is None:
+ weights[weights_biases_idx] = weight_rows
+ biases[weights_biases_idx] = bias_rows
+ else:
+ assert weights[weights_biases_idx] is not None
+ weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
+ biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
+
+ weights_biases_idx = (weights_biases_idx + 1) % split
+
+ return weights, biases
+
+
+def parse_list(value):
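+    # Accepts either a comma-separated string (e.g. "1,2,4" -> [1, 2, 4]) or an existing list.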
+ if isinstance(value, str):
+ value = value.split(",")
+ value = [int(v) for v in value]
+ elif isinstance(value, list):
+ pass
+ else:
+ raise ValueError(f"Can't parse list for type: {type(value)}")
+
+ return value
+
+
+# The code below is copied from the original convert_if_stage_2.py script.
+
+
+def get_super_res_unet(unet_checkpoint_path, check_param_count=True, sample_size=None):
+ orig_path = unet_checkpoint_path
+
+    # yaml.safe_load expects YAML text or a file handle, not a path, so open the config file first.
+    with open(os.path.join(orig_path, "config.yml")) as f:
+        original_unet_config = yaml.safe_load(f)
+    original_unet_config = original_unet_config["params"]
+
+ unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
+ unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int(
+ original_unet_config["channel_mult"].split(",")[-1]
+ )
+ if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
+ unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
+ unet_diffusers_config["class_embed_type"] = "timestep"
+ unet_diffusers_config["addition_embed_type"] = "text"
+
+ unet_diffusers_config["time_embedding_act_fn"] = "gelu"
+ unet_diffusers_config["resnet_skip_time_act"] = True
+ unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
+ unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
+ unet_diffusers_config["only_cross_attention"] = (
+ bool(original_unet_config["disable_self_attentions"])
+ if (
+ "disable_self_attentions" in original_unet_config
+ and isinstance(original_unet_config["disable_self_attentions"], int)
+ )
+ else True
+ )
+
+ if sample_size is None:
+ unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
+ else:
+ # The second upscaler unet's sample size is incorrectly specified
+ # in the config and is instead hardcoded in source
+ unet_diffusers_config["sample_size"] = sample_size
+
+ unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")
+
+    if check_param_count:
+        # check that the architecture matches - this is a bit slow
+        verify_param_count(orig_path, unet_diffusers_config)
+
+ converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
+ unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
+ )
+ converted_keys = converted_unet_checkpoint.keys()
+
+ model = UNet2DConditionModel(**unet_diffusers_config)
+ expected_weights = model.state_dict().keys()
+
+ diff_c_e = set(converted_keys) - set(expected_weights)
+ diff_e_c = set(expected_weights) - set(converted_keys)
+
+ assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}"
+ assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}"
+
+ model.load_state_dict(converted_unet_checkpoint)
+
+ return model
+
+
+def superres_create_unet_diffusers_config(original_unet_config):
+ attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
+ attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
+
+ channel_mult = parse_list(original_unet_config["channel_mult"])
+ block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
+
+ down_block_types = []
+ resolution = 1
+
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnDownBlock2D"
+ elif original_unet_config["resblock_updown"]:
+ block_type = "ResnetDownsampleBlock2D"
+ else:
+ block_type = "DownBlock2D"
+
+ down_block_types.append(block_type)
+
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnUpBlock2D"
+ elif original_unet_config["resblock_updown"]:
+ block_type = "ResnetUpsampleBlock2D"
+ else:
+ block_type = "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ head_dim = original_unet_config["num_head_channels"]
+ use_linear_projection = (
+ original_unet_config["use_linear_in_transformer"]
+ if "use_linear_in_transformer" in original_unet_config
+ else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim = [5, 10, 20, 20]
+
+ class_embed_type = None
+ projection_class_embeddings_input_dim = None
+
+ if "num_classes" in original_unet_config:
+ if original_unet_config["num_classes"] == "sequential":
+ class_embed_type = "projection"
+ assert "adm_in_channels" in original_unet_config
+ projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
+ else:
+ raise NotImplementedError(
+ f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
+ )
+
+ config = {
+ "in_channels": original_unet_config["in_channels"],
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": tuple(original_unet_config["num_res_blocks"]),
+ "cross_attention_dim": original_unet_config["encoder_channels"],
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "out_channels": original_unet_config["out_channels"],
+ "up_block_types": tuple(up_block_types),
+ "upcast_attention": False, # TODO: guessing
+ "cross_attention_norm": "group_norm",
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
+ "act_fn": "gelu",
+ }
+
+ if original_unet_config["use_scale_shift_norm"]:
+ config["resnet_time_scale_shift"] = "scale_shift"
+
+ return config
+
+
+def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] is None:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ if "encoder_proj.weight" in unet_state_dict:
+ new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"]
+ new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"]
+
+ if "encoder_pooling.0.weight" in unet_state_dict:
+ mapping = {
+ "encoder_pooling.0": "add_embedding.norm1",
+ "encoder_pooling.1": "add_embedding.pool",
+ "encoder_pooling.2": "add_embedding.proj",
+ "encoder_pooling.3": "add_embedding.norm2",
+ }
+ for key in unet_state_dict.keys():
+ if key.startswith("encoder_pooling"):
+ prefix = key[: len("encoder_pooling.0")]
+ new_key = key.replace(prefix, mapping[prefix])
+ new_checkpoint[new_key] = unet_state_dict[key]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
+ for layer_id in range(num_output_blocks)
+ }
+ if not isinstance(config["layers_per_block"], int):
+ layers_per_block_list = [e + 1 for e in config["layers_per_block"]]
+ layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
+ downsampler_ids = layers_per_block_cumsum
+ else:
+ # TODO need better check than i in [4, 8, 12, 16]
+ downsampler_ids = [4, 8, 12, 16]
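+ # Illustrative: with layers_per_block == 3 (the int case) each original down block
+ # spans 4 input_blocks, so downsamplers sit at indices 4, 8, 12, 16; with a per-block
+ # list such as (2, 2, 2, 2) the cumsum [3, 6, 9, 12] plays the same role. The
+ # (2, 2, 2, 2) figure is only an assumed example, not taken from any checkpoint.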
+
+ for i in range(1, num_input_blocks):
+ if isinstance(config["layers_per_block"], int):
+ layers_per_block = config["layers_per_block"]
+ block_id = (i - 1) // (layers_per_block + 1)
+ layer_in_block_id = (i - 1) % (layers_per_block + 1)
+ else:
+ block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n)
+ passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
+ layer_in_block_id = (i - 1) - passed_blocks
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+
+ block_type = config["down_block_types"][block_id]
+ if (
+ block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D"
+ ) and i in downsampler_ids:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
+ else:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ old_path = f"input_blocks.{i}.1"
+ new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": old_path, "new": new_path}
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ unet_state_dict,
+ additional_replacements=[meta_path],
+ config=config,
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ old_path = "middle_block.1"
+ new_path = "mid_block.attentions.0"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ if not isinstance(config["layers_per_block"], int):
+ layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]]))
+ layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
+
+ for i in range(num_output_blocks):
+ if isinstance(config["layers_per_block"], int):
+ layers_per_block = config["layers_per_block"]
+ block_id = i // (layers_per_block + 1)
+ layer_in_block_id = i % (layers_per_block + 1)
+ else:
+ block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n)
+ passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
+ layer_in_block_id = i - passed_blocks
+
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ # len(output_block_list) == 1 -> resnet
+ # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet
+ # len(output_block_list) == 3 -> resnet, attention, upscale resnet
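+ # e.g. a typical up block with attention and an upsampler yields
+ # {"0": resnet keys, "1": attention keys, "2": upscale-resnet keys}, i.e. len == 3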
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+
+ has_attention = True
+ if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]):
+ has_attention = False
+
+ maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # this layer has no attention
+ has_attention = False
+ maybe_attentions = []
+
+ if has_attention:
+ old_path = f"output_blocks.{i}.1"
+ new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(maybe_attentions)
+ meta_path = {
+ "old": old_path,
+ "new": new_path,
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0):
+ layer_id = len(output_block_list) - 1
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key]
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ return new_checkpoint
+
+
+def verify_param_count(orig_path, unet_diffusers_config):
+ if "-II-" in orig_path:
+ from deepfloyd_if.modules import IFStageII
+
+ if_II = IFStageII(device="cpu", dir_or_name=orig_path)
+ elif "-III-" in orig_path:
+ from deepfloyd_if.modules import IFStageIII
+
+ if_II = IFStageIII(device="cpu", dir_or_name=orig_path)
+ else:
+ assert f"Weird name. Should have -II- or -III- in path: {orig_path}"
+
+ unet = UNet2DConditionModel(**unet_diffusers_config)
+
+ # in params
+ assert_param_count(unet.time_embedding, if_II.model.time_embed)
+ assert_param_count(unet.conv_in, if_II.model.input_blocks[:1])
+
+ # downblocks
+ assert_param_count(unet.down_blocks[0], if_II.model.input_blocks[1:4])
+ assert_param_count(unet.down_blocks[1], if_II.model.input_blocks[4:7])
+ assert_param_count(unet.down_blocks[2], if_II.model.input_blocks[7:11])
+
+ if "-II-" in orig_path:
+ assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:17])
+ assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[17:])
+ if "-III-" in orig_path:
+ assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:15])
+ assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[15:20])
+ assert_param_count(unet.down_blocks[5], if_II.model.input_blocks[20:])
+
+ # mid block
+ assert_param_count(unet.mid_block, if_II.model.middle_block)
+
+ # up block
+ if "-II-" in orig_path:
+ assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:6])
+ assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[6:12])
+ assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[12:16])
+ assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[16:19])
+ assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[19:])
+ if "-III-" in orig_path:
+ assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:5])
+ assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[5:10])
+ assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[10:14])
+ assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[14:18])
+ assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[18:21])
+ assert_param_count(unet.up_blocks[5], if_II.model.output_blocks[21:24])
+
+ # out params
+ assert_param_count(unet.conv_norm_out, if_II.model.out[0])
+ assert_param_count(unet.conv_out, if_II.model.out[2])
+
+ # make sure all model architecture has same param count
+ assert_param_count(unet, if_II.model)
+
+
+def assert_param_count(model_1, model_2):
+ count_1 = sum(p.numel() for p in model_1.parameters())
+ count_2 = sum(p.numel() for p in model_2.parameters())
+ assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
+
+
+def superres_check_against_original(dump_path, unet_checkpoint_path):
+ model_path = dump_path
+ model = UNet2DConditionModel.from_pretrained(model_path)
+ model.to("cuda")
+ orig_path = unet_checkpoint_path
+
+ if "-II-" in orig_path:
+ from deepfloyd_if.modules import IFStageII
+
+ if_II_model = IFStageII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
+ elif "-III-" in orig_path:
+ from deepfloyd_if.modules import IFStageIII
+
+ if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
+
+ batch_size = 1
+ channels = model.config.in_channels // 2
+ # use a fixed 1024x1024 resolution for this check rather than model.config.sample_size
+ height = 1024
+ width = 1024
+
+ torch.manual_seed(0)
+
+ latents = torch.randn((batch_size, channels, height, width), device=model.device)
+ image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device)
+
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype)
+ t = torch.tensor([5], device=model.device).to(model.dtype)
+
+ seq_len = 64
+ encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to(
+ model.dtype
+ )
+
+ fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype)
+
+ with torch.no_grad():
+ out = if_II_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states)
+
+ if_II_model.to("cpu")
+ del if_II_model
+ import gc
+
+ torch.cuda.empty_cache()
+ gc.collect()
+ print(50 * "=")
+
+ with torch.no_grad():
+ noise_pred = model(
+ sample=latent_model_input,
+ encoder_hidden_states=encoder_hidden_states,
+ class_labels=fake_class_labels,
+ timestep=t,
+ ).sample
+
+ print("Out shape", noise_pred.shape)
+ print("Diff", (out - noise_pred).abs().sum())
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/diffusers/scripts/convert_k_upscaler_to_diffusers.py b/diffusers/scripts/convert_k_upscaler_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..62abedd737855ca0b0bc9abb75c9b6fb91d5bde2
--- /dev/null
+++ b/diffusers/scripts/convert_k_upscaler_to_diffusers.py
@@ -0,0 +1,297 @@
+import argparse
+
+import huggingface_hub
+import k_diffusion as K
+import torch
+
+from diffusers import UNet2DConditionModel
+
+
+UPSCALER_REPO = "pcuenq/k-upscaler"
+
+
+def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
+ rv = {
+ # norm1
+ f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
+ f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
+ # conv1
+ f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
+ f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
+ # norm2
+ f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
+ f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
+ # conv2
+ f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
+ f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
+ }
+
+ if resnet.conv_shortcut is not None:
+ rv.update(
+ {
+ f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
+ }
+ )
+
+ return rv
+
+
+def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
+ weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
+ bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
+ rv = {
+ # norm
+ f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
+ f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
+ # to_q
+ f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
+ # to_k
+ f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
+ # to_v
+ f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
+ # to_out
+ f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
+ .squeeze(-1)
+ .squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
+ }
+
+ return rv
+
+
+def cross_attn_to_diffusers_checkpoint(
+ checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
+):
+ weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
+ bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)
+
+ rv = {
+ # norm2 (ada groupnorm)
+ f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
+ f"{attention_prefix}.norm_dec.mapper.weight"
+ ],
+ f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
+ f"{attention_prefix}.norm_dec.mapper.bias"
+ ],
+ # layernorm on encoder_hidden_state
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
+ f"{attention_prefix}.norm_enc.weight"
+ ],
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
+ f"{attention_prefix}.norm_enc.bias"
+ ],
+ # to_q
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
+ f"{attention_prefix}.q_proj.weight"
+ ]
+ .squeeze(-1)
+ .squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
+ f"{attention_prefix}.q_proj.bias"
+ ],
+ # to_k
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
+ # to_v
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
+ # to_out
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
+ f"{attention_prefix}.out_proj.weight"
+ ]
+ .squeeze(-1)
+ .squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
+ f"{attention_prefix}.out_proj.bias"
+ ],
+ }
+
+ return rv
+
+
+def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
+ block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
+ block_prefix = f"{block_prefix}.{block_idx}"
+
+ diffusers_checkpoint = {}
+
+ if not hasattr(block, "attentions"):
+ n = 1 # resnet only
+ elif not block.attentions[0].add_self_attention:
+ n = 2 # resnet -> cross-attention
+ else:
+ n = 3 # resnet -> self-attention -> cross-attention
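+ # Rough sketch of the layout this assumes (illustrative): with n == 3 on an "up"
+ # block, original layer 3*i is the i-th resnet, 3*i + 1 its self-attention and
+ # 3*i + 2 its cross-attention; "down" blocks follow the same pattern offset by one.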
+
+ for resnet_idx, resnet in enumerate(block.resnets):
+ diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
+ idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
+ resnet_prefix = f"{block_prefix}.{idx}"
+
+ diffusers_checkpoint.update(
+ resnet_to_diffusers_checkpoint(
+ resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+ )
+ )
+
+ if hasattr(block, "attentions"):
+ for attention_idx, attention in enumerate(block.attentions):
+ diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
+ idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
+ self_attention_prefix = f"{block_prefix}.{idx}"
+ cross_attention_index = 1 if not attention.add_self_attention else 2
+ idx = (
+ n * attention_idx + cross_attention_index
+ if block_type == "up"
+ else n * attention_idx + cross_attention_index + 1
+ )
+ cross_attention_prefix = f"{block_prefix}.{idx}"
+
+ diffusers_checkpoint.update(
+ cross_attn_to_diffusers_checkpoint(
+ checkpoint,
+ diffusers_attention_prefix=diffusers_attention_prefix,
+ diffusers_attention_index=2,
+ attention_prefix=cross_attention_prefix,
+ )
+ )
+
+ if attention.add_self_attention is True:
+ diffusers_checkpoint.update(
+ self_attn_to_diffusers_checkpoint(
+ checkpoint,
+ diffusers_attention_prefix=diffusers_attention_prefix,
+ attention_prefix=self_attention_prefix,
+ )
+ )
+
+ return diffusers_checkpoint
+
+
+def unet_to_diffusers_checkpoint(model, checkpoint):
+ diffusers_checkpoint = {}
+
+ # pre-processing
+ diffusers_checkpoint.update(
+ {
+ "conv_in.weight": checkpoint["inner_model.proj_in.weight"],
+ "conv_in.bias": checkpoint["inner_model.proj_in.bias"],
+ }
+ )
+
+ # timestep and class embedding
+ diffusers_checkpoint.update(
+ {
+ "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
+ "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
+ "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
+ "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
+ "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
+ "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
+ }
+ )
+
+ # down_blocks
+ for down_block_idx, down_block in enumerate(model.down_blocks):
+ diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))
+
+ # up_blocks
+ for up_block_idx, up_block in enumerate(model.up_blocks):
+ diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))
+
+ # post-processing
+ diffusers_checkpoint.update(
+ {
+ "conv_out.weight": checkpoint["inner_model.proj_out.weight"],
+ "conv_out.bias": checkpoint["inner_model.proj_out.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+def unet_model_from_original_config(original_config):
+ in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
+ out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)
+
+ block_out_channels = original_config["channels"]
+
+ assert (
+ len(set(original_config["depths"])) == 1
+ ), "UNet2DConditionModel currently do not support blocks with different number of layers"
+ layers_per_block = original_config["depths"][0]
+
+ class_labels_dim = original_config["mapping_cond_dim"]
+ cross_attention_dim = original_config["cross_cond_dim"]
+
+ attn1_types = []
+ attn2_types = []
+ for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
+ if s:
+ a1 = "self"
+ a2 = "cross" if c else None
+ elif c:
+ a1 = "cross"
+ a2 = None
+ else:
+ a1 = None
+ a2 = None
+ attn1_types.append(a1)
+ attn2_types.append(a2)
+
+ unet = UNet2DConditionModel(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
+ mid_block_type=None,
+ up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn="gelu",
+ norm_num_groups=None,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=64,
+ time_cond_proj_dim=class_labels_dim,
+ resnet_time_scale_shift="scale_shift",
+ time_embedding_type="fourier",
+ timestep_post_act="gelu",
+ conv_in_kernel=1,
+ conv_out_kernel=1,
+ )
+
+ return unet
+
+
+def main(args):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
+ orig_weights_path = huggingface_hub.hf_hub_download(
+ UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
+ )
+ print(f"loading original model configuration from {orig_config_path}")
+ print(f"loading original model checkpoint from {orig_weights_path}")
+
+ print("converting to diffusers unet")
+ orig_config = K.config.load_config(open(orig_config_path))["model"]
+ model = unet_model_from_original_config(orig_config)
+
+ orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
+ converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)
+
+ model.load_state_dict(converted_checkpoint, strict=True)
+ model.save_pretrained(args.dump_path)
+ print(f"saving converted unet model in {args.dump_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py b/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..5135eaed5b98391733185122f7f147201f4e4591
--- /dev/null
+++ b/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py
@@ -0,0 +1,1159 @@
+import argparse
+import tempfile
+
+import torch
+from accelerate import load_checkpoint_and_dispatch
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
+from diffusers.models.transformers.prior_transformer import PriorTransformer
+from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
+from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
+
+
+r"""
+Example - From the diffusers root directory:
+
+Download weights:
+```sh
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th
+```
+
+Convert the model:
+```sh
+$ python scripts/convert_kakao_brain_unclip_to_diffusers.py \
+ --decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \
+ --super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \
+ --prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \
+ --clip_stat_path ./ViT-L-14_stats.th \
+ --dump_path <path where to save the converted pipeline>
+```
+"""
+
+
+# prior
+
+PRIOR_ORIGINAL_PREFIX = "model"
+
+# Uses default arguments
+PRIOR_CONFIG = {}
+
+
+def prior_model_from_original_config():
+ model = PriorTransformer(**PRIOR_CONFIG)
+
+ return model
+
+
+def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint):
+ diffusers_checkpoint = {}
+
+ # .time_embed.0 -> .time_embedding.linear_1
+ diffusers_checkpoint.update(
+ {
+ "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"],
+ "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"],
+ }
+ )
+
+ # .clip_img_proj -> .proj_in
+ diffusers_checkpoint.update(
+ {
+ "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"],
+ "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"],
+ }
+ )
+
+ # .text_emb_proj -> .embedding_proj
+ diffusers_checkpoint.update(
+ {
+ "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"],
+ "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"],
+ }
+ )
+
+ # .text_enc_proj -> .encoder_hidden_states_proj
+ diffusers_checkpoint.update(
+ {
+ "encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"],
+ "encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"],
+ }
+ )
+
+ # .positional_embedding -> .positional_embedding
+ diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]})
+
+ # .prd_emb -> .prd_embedding
+ diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]})
+
+ # .time_embed.2 -> .time_embedding.linear_2
+ diffusers_checkpoint.update(
+ {
+ "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"],
+ "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"],
+ }
+ )
+
+ # .resblocks. -> .transformer_blocks.
+ for idx in range(len(model.transformer_blocks)):
+ diffusers_transformer_prefix = f"transformer_blocks.{idx}"
+ original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}"
+
+ # .attn -> .attn1
+ diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
+ original_attention_prefix = f"{original_transformer_prefix}.attn"
+ diffusers_checkpoint.update(
+ prior_attention_to_diffusers(
+ checkpoint,
+ diffusers_attention_prefix=diffusers_attention_prefix,
+ original_attention_prefix=original_attention_prefix,
+ attention_head_dim=model.attention_head_dim,
+ )
+ )
+
+ # .mlp -> .ff
+ diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
+ original_ff_prefix = f"{original_transformer_prefix}.mlp"
+ diffusers_checkpoint.update(
+ prior_ff_to_diffusers(
+ checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
+ )
+ )
+
+ # .ln_1 -> .norm1
+ diffusers_checkpoint.update(
+ {
+ f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
+ f"{original_transformer_prefix}.ln_1.weight"
+ ],
+ f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
+ }
+ )
+
+ # .ln_2 -> .norm3
+ diffusers_checkpoint.update(
+ {
+ f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
+ f"{original_transformer_prefix}.ln_2.weight"
+ ],
+ f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
+ }
+ )
+
+ # .final_ln -> .norm_out
+ diffusers_checkpoint.update(
+ {
+ "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"],
+ "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"],
+ }
+ )
+
+ # .out_proj -> .proj_to_clip_embeddings
+ diffusers_checkpoint.update(
+ {
+ "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"],
+ "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"],
+ }
+ )
+
+ # clip stats
+ clip_mean, clip_std = clip_stats_checkpoint
+ clip_mean = clip_mean[None, :]
+ clip_std = clip_std[None, :]
+
+ diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std})
+
+ return diffusers_checkpoint
+
+
+def prior_attention_to_diffusers(
+ checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
+):
+ diffusers_checkpoint = {}
+
+ # .c_qkv -> .{to_q, to_k, to_v}
+ [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
+ weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
+ bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
+ split=3,
+ chunk_size=attention_head_dim,
+ )
+
+ diffusers_checkpoint.update(
+ {
+ f"{diffusers_attention_prefix}.to_q.weight": q_weight,
+ f"{diffusers_attention_prefix}.to_q.bias": q_bias,
+ f"{diffusers_attention_prefix}.to_k.weight": k_weight,
+ f"{diffusers_attention_prefix}.to_k.bias": k_bias,
+ f"{diffusers_attention_prefix}.to_v.weight": v_weight,
+ f"{diffusers_attention_prefix}.to_v.bias": v_bias,
+ }
+ )
+
+ # .c_proj -> .to_out.0
+ diffusers_checkpoint.update(
+ {
+ f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
+ f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
+ diffusers_checkpoint = {
+ # .c_fc -> .net.0.proj
+ f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
+ f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
+ # .c_proj -> .net.2
+ f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
+ f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
+ }
+
+ return diffusers_checkpoint
+
+
+# done prior
+
+
+# decoder
+
+DECODER_ORIGINAL_PREFIX = "model"
+
+# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can
+# update it then.
+DECODER_CONFIG = {
+ "sample_size": 64,
+ "layers_per_block": 3,
+ "down_block_types": (
+ "ResnetDownsampleBlock2D",
+ "SimpleCrossAttnDownBlock2D",
+ "SimpleCrossAttnDownBlock2D",
+ "SimpleCrossAttnDownBlock2D",
+ ),
+ "up_block_types": (
+ "SimpleCrossAttnUpBlock2D",
+ "SimpleCrossAttnUpBlock2D",
+ "SimpleCrossAttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
+ "block_out_channels": (320, 640, 960, 1280),
+ "in_channels": 3,
+ "out_channels": 6,
+ "cross_attention_dim": 1536,
+ "class_embed_type": "identity",
+ "attention_head_dim": 64,
+ "resnet_time_scale_shift": "scale_shift",
+}
+
+
+def decoder_model_from_original_config():
+ model = UNet2DConditionModel(**DECODER_CONFIG)
+
+ return model
+
+
+def decoder_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
+ diffusers_checkpoint = {}
+
+ original_unet_prefix = DECODER_ORIGINAL_PREFIX
+ num_head_channels = DECODER_CONFIG["attention_head_dim"]
+
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
+ diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
+
+ # .input_blocks -> .down_blocks
+
+ original_down_block_idx = 1
+
+ for diffusers_down_block_idx in range(len(model.down_blocks)):
+ checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ diffusers_down_block_idx=diffusers_down_block_idx,
+ original_down_block_idx=original_down_block_idx,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=num_head_channels,
+ )
+
+ original_down_block_idx += num_original_down_blocks
+
+ diffusers_checkpoint.update(checkpoint_update)
+
+ # done .input_blocks -> .down_blocks
+
+ diffusers_checkpoint.update(
+ unet_midblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=num_head_channels,
+ )
+ )
+
+ # .output_blocks -> .up_blocks
+
+ original_up_block_idx = 0
+
+ for diffusers_up_block_idx in range(len(model.up_blocks)):
+ checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ diffusers_up_block_idx=diffusers_up_block_idx,
+ original_up_block_idx=original_up_block_idx,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=num_head_channels,
+ )
+
+ original_up_block_idx += num_original_up_blocks
+
+ diffusers_checkpoint.update(checkpoint_update)
+
+ # done .output_blocks -> .up_blocks
+
+ diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
+ diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
+
+ return diffusers_checkpoint
+
+
+# done decoder
+
+# text proj
+
+
+def text_proj_from_original_config():
+ # From the conditional unet constructor where the dimension of the projected time embeddings is
+ # constructed
+ time_embed_dim = DECODER_CONFIG["block_out_channels"][0] * 4
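+ # with the DECODER_CONFIG above this is 320 * 4 = 1280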
+
+ cross_attention_dim = DECODER_CONFIG["cross_attention_dim"]
+
+ model = UnCLIPTextProjModel(time_embed_dim=time_embed_dim, cross_attention_dim=cross_attention_dim)
+
+ return model
+
+
+# Note that the input checkpoint is the original decoder checkpoint
+def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
+ diffusers_checkpoint = {
+ # .text_seq_proj.0 -> .encoder_hidden_states_proj
+ "encoder_hidden_states_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"],
+ "encoder_hidden_states_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"],
+ # .text_seq_proj.1 -> .text_encoder_hidden_states_norm
+ "text_encoder_hidden_states_norm.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.weight"],
+ "text_encoder_hidden_states_norm.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.bias"],
+ # .clip_tok_proj -> .clip_extra_context_tokens_proj
+ "clip_extra_context_tokens_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.weight"],
+ "clip_extra_context_tokens_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.bias"],
+ # .text_feat_proj -> .embedding_proj
+ "embedding_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.weight"],
+ "embedding_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.bias"],
+ # .cf_param -> .learned_classifier_free_guidance_embeddings
+ "learned_classifier_free_guidance_embeddings": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.cf_param"],
+ # .clip_emb -> .clip_image_embeddings_project_to_time_embeddings
+ "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint[
+ f"{DECODER_ORIGINAL_PREFIX}.clip_emb.weight"
+ ],
+ "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint[
+ f"{DECODER_ORIGINAL_PREFIX}.clip_emb.bias"
+ ],
+ }
+
+ return diffusers_checkpoint
+
+
+# done text proj
+
+# super res unet first steps
+
+SUPER_RES_UNET_FIRST_STEPS_PREFIX = "model_first_steps"
+
+SUPER_RES_UNET_FIRST_STEPS_CONFIG = {
+ "sample_size": 256,
+ "layers_per_block": 3,
+ "down_block_types": (
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ ),
+ "up_block_types": (
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ "block_out_channels": (320, 640, 960, 1280),
+ "in_channels": 6,
+ "out_channels": 3,
+ "add_attention": False,
+}
+
+
+def super_res_unet_first_steps_model_from_original_config():
+ model = UNet2DModel(**SUPER_RES_UNET_FIRST_STEPS_CONFIG)
+
+ return model
+
+
+def super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
+ diffusers_checkpoint = {}
+
+ original_unet_prefix = SUPER_RES_UNET_FIRST_STEPS_PREFIX
+
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
+ diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
+
+ # .input_blocks -> .down_blocks
+
+ original_down_block_idx = 1
+
+ for diffusers_down_block_idx in range(len(model.down_blocks)):
+ checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ diffusers_down_block_idx=diffusers_down_block_idx,
+ original_down_block_idx=original_down_block_idx,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=None,
+ )
+
+ original_down_block_idx += num_original_down_blocks
+
+ diffusers_checkpoint.update(checkpoint_update)
+
+ diffusers_checkpoint.update(
+ unet_midblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=None,
+ )
+ )
+
+ # .output_blocks -> .up_blocks
+
+ original_up_block_idx = 0
+
+ for diffusers_up_block_idx in range(len(model.up_blocks)):
+ checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ diffusers_up_block_idx=diffusers_up_block_idx,
+ original_up_block_idx=original_up_block_idx,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=None,
+ )
+
+ original_up_block_idx += num_original_up_blocks
+
+ diffusers_checkpoint.update(checkpoint_update)
+
+ # done .output_blocks -> .up_blocks
+
+ diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
+ diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
+
+ return diffusers_checkpoint
+
+
+# done super res unet first steps
+
+# super res unet last step
+
+SUPER_RES_UNET_LAST_STEP_PREFIX = "model_last_step"
+
+SUPER_RES_UNET_LAST_STEP_CONFIG = {
+ "sample_size": 256,
+ "layers_per_block": 3,
+ "down_block_types": (
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ ),
+ "up_block_types": (
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ "block_out_channels": (320, 640, 960, 1280),
+ "in_channels": 6,
+ "out_channels": 3,
+ "add_attention": False,
+}
+
+
+def super_res_unet_last_step_model_from_original_config():
+ model = UNet2DModel(**SUPER_RES_UNET_LAST_STEP_CONFIG)
+
+ return model
+
+
+def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
+ diffusers_checkpoint = {}
+
+ original_unet_prefix = SUPER_RES_UNET_LAST_STEP_PREFIX
+
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
+ diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
+
+ # .input_blocks -> .down_blocks
+
+ original_down_block_idx = 1
+
+ for diffusers_down_block_idx in range(len(model.down_blocks)):
+ checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ diffusers_down_block_idx=diffusers_down_block_idx,
+ original_down_block_idx=original_down_block_idx,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=None,
+ )
+
+ original_down_block_idx += num_original_down_blocks
+
+ diffusers_checkpoint.update(checkpoint_update)
+
+ diffusers_checkpoint.update(
+ unet_midblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=None,
+ )
+ )
+
+ # .output_blocks -> .up_blocks
+
+ original_up_block_idx = 0
+
+ for diffusers_up_block_idx in range(len(model.up_blocks)):
+ checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ diffusers_up_block_idx=diffusers_up_block_idx,
+ original_up_block_idx=original_up_block_idx,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=None,
+ )
+
+ original_up_block_idx += num_original_up_blocks
+
+ diffusers_checkpoint.update(checkpoint_update)
+
+ # done .output_blocks -> .up_blocks
+
+ diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
+ diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
+
+ return diffusers_checkpoint
+
+
+# done super res unet last step
+
+
+# unet utils
+
+
+# .time_embed -> .time_embedding
+def unet_time_embeddings(checkpoint, original_unet_prefix):
+ diffusers_checkpoint = {}
+
+ diffusers_checkpoint.update(
+ {
+ "time_embedding.linear_1.weight": checkpoint[f"{original_unet_prefix}.time_embed.0.weight"],
+ "time_embedding.linear_1.bias": checkpoint[f"{original_unet_prefix}.time_embed.0.bias"],
+ "time_embedding.linear_2.weight": checkpoint[f"{original_unet_prefix}.time_embed.2.weight"],
+ "time_embedding.linear_2.bias": checkpoint[f"{original_unet_prefix}.time_embed.2.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+# .input_blocks.0 -> .conv_in
+def unet_conv_in(checkpoint, original_unet_prefix):
+ diffusers_checkpoint = {}
+
+ diffusers_checkpoint.update(
+ {
+ "conv_in.weight": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.weight"],
+ "conv_in.bias": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+# .out.0 -> .conv_norm_out
+def unet_conv_norm_out(checkpoint, original_unet_prefix):
+ diffusers_checkpoint = {}
+
+ diffusers_checkpoint.update(
+ {
+ "conv_norm_out.weight": checkpoint[f"{original_unet_prefix}.out.0.weight"],
+ "conv_norm_out.bias": checkpoint[f"{original_unet_prefix}.out.0.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+# .out.2 -> .conv_out
+def unet_conv_out(checkpoint, original_unet_prefix):
+ diffusers_checkpoint = {}
+
+ diffusers_checkpoint.update(
+ {
+ "conv_out.weight": checkpoint[f"{original_unet_prefix}.out.2.weight"],
+ "conv_out.bias": checkpoint[f"{original_unet_prefix}.out.2.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+# .input_blocks -> .down_blocks
+def unet_downblock_to_diffusers_checkpoint(
+ model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, original_unet_prefix, num_head_channels
+):
+ diffusers_checkpoint = {}
+
+ diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets"
+ original_down_block_prefix = f"{original_unet_prefix}.input_blocks"
+
+ down_block = model.down_blocks[diffusers_down_block_idx]
+
+ num_resnets = len(down_block.resnets)
+
+ if down_block.downsamplers is None:
+ downsampler = False
+ else:
+ assert len(down_block.downsamplers) == 1
+ downsampler = True
+ # The downsample block is also a resnet
+ num_resnets += 1
+
+ for resnet_idx_inc in range(num_resnets):
+ full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0"
+
+ if downsampler and resnet_idx_inc == num_resnets - 1:
+ # this is a downsample block
+ full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0"
+ else:
+ # this is a regular resnet block
+ full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
+
+ diffusers_checkpoint.update(
+ resnet_to_diffusers_checkpoint(
+ checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
+ )
+ )
+
+ if hasattr(down_block, "attentions"):
+ num_attentions = len(down_block.attentions)
+ diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions"
+
+ for attention_idx_inc in range(num_attentions):
+ full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1"
+ full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
+
+ diffusers_checkpoint.update(
+ attention_to_diffusers_checkpoint(
+ checkpoint,
+ attention_prefix=full_attention_prefix,
+ diffusers_attention_prefix=full_diffusers_attention_prefix,
+ num_head_channels=num_head_channels,
+ )
+ )
+
+ num_original_down_blocks = num_resnets
+
+ return diffusers_checkpoint, num_original_down_blocks
+
+
+# .middle_block -> .mid_block
+def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, original_unet_prefix, num_head_channels):
+ diffusers_checkpoint = {}
+
+ # block 0
+
+ original_block_idx = 0
+
+ diffusers_checkpoint.update(
+ resnet_to_diffusers_checkpoint(
+ checkpoint,
+ diffusers_resnet_prefix="mid_block.resnets.0",
+ resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
+ )
+ )
+
+ original_block_idx += 1
+
+ # optional block 1
+
+ if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None:
+ diffusers_checkpoint.update(
+ attention_to_diffusers_checkpoint(
+ checkpoint,
+ diffusers_attention_prefix="mid_block.attentions.0",
+ attention_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
+ num_head_channels=num_head_channels,
+ )
+ )
+ original_block_idx += 1
+
+ # block 1 or block 2
+
+ diffusers_checkpoint.update(
+ resnet_to_diffusers_checkpoint(
+ checkpoint,
+ diffusers_resnet_prefix="mid_block.resnets.1",
+ resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
+ )
+ )
+
+ return diffusers_checkpoint
+
+
+# .output_blocks -> .up_blocks
+def unet_upblock_to_diffusers_checkpoint(
+ model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, original_unet_prefix, num_head_channels
+):
+ diffusers_checkpoint = {}
+
+ diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets"
+ original_up_block_prefix = f"{original_unet_prefix}.output_blocks"
+
+ up_block = model.up_blocks[diffusers_up_block_idx]
+
+ num_resnets = len(up_block.resnets)
+
+ if up_block.upsamplers is None:
+ upsampler = False
+ else:
+ assert len(up_block.upsamplers) == 1
+ upsampler = True
+ # The upsample block is also a resnet
+ num_resnets += 1
+
+ has_attentions = hasattr(up_block, "attentions")
+
+ for resnet_idx_inc in range(num_resnets):
+ if upsampler and resnet_idx_inc == num_resnets - 1:
+ # this is an upsample block
+ if has_attentions:
+ # There is a middle attention block that we skip
+ original_resnet_block_idx = 2
+ else:
+ original_resnet_block_idx = 1
+
+ # we add the `minus 1` because the last two resnets are stuck together in the same output block
+ full_resnet_prefix = (
+ f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}"
+ )
+
+ full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0"
+ else:
+ # this is a regular resnet block
+ full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0"
+ full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
+
+ diffusers_checkpoint.update(
+ resnet_to_diffusers_checkpoint(
+ checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
+ )
+ )
+
+ if has_attentions:
+ num_attentions = len(up_block.attentions)
+ diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions"
+
+ for attention_idx_inc in range(num_attentions):
+ full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1"
+ full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
+
+ diffusers_checkpoint.update(
+ attention_to_diffusers_checkpoint(
+ checkpoint,
+ attention_prefix=full_attention_prefix,
+ diffusers_attention_prefix=full_diffusers_attention_prefix,
+ num_head_channels=num_head_channels,
+ )
+ )
+
+ num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets
+
+ return diffusers_checkpoint, num_original_down_blocks
+
+
+def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
+ diffusers_checkpoint = {
+ f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"],
+ f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"],
+ f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"],
+ f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"],
+ f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"],
+ f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"],
+ f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"],
+ f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"],
+ f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"],
+ f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"],
+ }
+
+ skip_connection_prefix = f"{resnet_prefix}.skip_connection"
+
+ if f"{skip_connection_prefix}.weight" in checkpoint:
+ diffusers_checkpoint.update(
+ {
+ f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"],
+ f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels):
+ diffusers_checkpoint = {}
+
+ #